Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 227 additions & 36 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,156 @@
#define MY_GET
#include "../env_binding.h"

static int clipped_debug_goal_count(Drive *env, Agent *agent) {
assert(env != NULL);
assert(agent != NULL);
int goal_count = get_agent_goal_count(env, agent);
if (goal_count < 0) {
return 0;
}
if (goal_count > MAX_TARGET_WAYPOINTS) {
return MAX_TARGET_WAYPOINTS;
}
return goal_count;
}

static PyObject *build_goal_lane_ids(Agent *agent, int goal_count) {
assert(agent != NULL);
assert(goal_count >= 0);
PyObject *lst = PyList_New(goal_count);
if (lst == NULL) {
return NULL;
}
for (int i = 0; i < goal_count; i++) {
PyObject *v = PyLong_FromLong(agent->goal_lane_ids[i]);
if (v == NULL || PyList_SetItem(lst, i, v) < 0) {
Py_XDECREF(v);
Py_DECREF(lst);
return NULL;
}
}
return lst;
}

static PyObject *build_goal_lane_s(Agent *agent, int goal_count) {
assert(agent != NULL);
assert(goal_count >= 0);
PyObject *lst = PyList_New(goal_count);
if (lst == NULL) {
return NULL;
}
for (int i = 0; i < goal_count; i++) {
PyObject *v = PyFloat_FromDouble((double) agent->goal_lane_s[i]);
if (v == NULL || PyList_SetItem(lst, i, v) < 0) {
Py_XDECREF(v);
Py_DECREF(lst);
return NULL;
}
}
return lst;
}

static PyObject *build_goal_positions(Agent *agent, int goal_count) {
assert(agent != NULL);
assert(goal_count >= 0);
PyObject *lst = PyList_New(goal_count);
if (lst == NULL) {
return NULL;
}
for (int i = 0; i < goal_count; i++) {
PyObject *pos = Py_BuildValue(
"(ddd)",
(double) agent->goal_positions_x[i],
(double) agent->goal_positions_y[i],
(double) agent->goal_positions_z[i]);
if (pos == NULL || PyList_SetItem(lst, i, pos) < 0) {
Py_XDECREF(pos);
Py_DECREF(lst);
return NULL;
}
}
return lst;
}

static PyObject *build_goal_route_distances(Drive *env, Agent *agent, int goal_count) {
assert(env != NULL);
assert(agent != NULL);
PyObject *lst = PyList_New(goal_count);
if (lst == NULL) {
return NULL;
}
int from_lane_idx = agent->route_length > 0 ? agent->route[0] : agent->current_lane_idx;
float from_s = 0.0f;
if (from_lane_idx >= 0 && from_lane_idx < env->num_road_elements) {
from_s = lane_s_at_position(&env->road_elements[from_lane_idx], agent->sim_x, agent->sim_y);
}
for (int i = 0; i < goal_count; i++) {
float distance = route_distance_between_lane_positions(
env,
from_lane_idx,
from_s,
agent->goal_lane_ids[i],
agent->goal_lane_s[i]);
PyObject *v = PyFloat_FromDouble((double) distance);
if (v == NULL || PyList_SetItem(lst, i, v) < 0) {
Py_XDECREF(v);
Py_DECREF(lst);
return NULL;
}
from_lane_idx = agent->goal_lane_ids[i];
from_s = agent->goal_lane_s[i];
}
return lst;
}

static int set_agent_goal_debug_fields(PyObject *agent_dict, Drive *env, Agent *agent) {
assert(agent_dict != NULL);
assert(env != NULL);
int goal_count = clipped_debug_goal_count(env, agent);
PyObject *v = PyLong_FromLong(goal_count);
if (v == NULL || PyDict_SetItemString(agent_dict, "num_active_goals", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);

v = PyLong_FromLong(agent->current_goal_idx);
if (v == NULL || PyDict_SetItemString(agent_dict, "current_goal_idx", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);

v = build_goal_lane_ids(agent, goal_count);
if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_ids", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);

v = build_goal_lane_s(agent, goal_count);
if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_s", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);

v = build_goal_route_distances(env, agent, goal_count);
if (v == NULL || PyDict_SetItemString(agent_dict, "goal_route_distances", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);

v = build_goal_positions(agent, goal_count);
if (v == NULL || PyDict_SetItemString(agent_dict, "goal_positions", v) < 0) {
Py_XDECREF(v);
return 1;
}
Py_DECREF(v);
return 0;
}

static int my_put(Env *env, PyObject *args, PyObject *kwargs) {
PyObject *obs = PyDict_GetItemString(kwargs, "observations");
if (!PyObject_TypeCheck(obs, &PyArray_Type)) {
Expand Down Expand Up @@ -727,6 +877,12 @@ static PyObject *my_get(PyObject *dict, Env *env) {
}
Py_DECREF(pf);

if (set_agent_goal_debug_fields(agent, env, a)) {
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}

/* Status flags */
tmp = PyLong_FromLong(a->stopped);
if (!tmp) {
Expand Down Expand Up @@ -922,12 +1078,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(route_list);
} else {
if (PyDict_SetItemString(agent, "route", Py_None) < 0) {
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}
} else if (PyDict_SetItemString(agent, "route", Py_None) < 0) {
Py_DECREF(agent);
Py_DECREF(agents_list);
return NULL;
}

PyList_SetItem(agents_list, i, agent);
Expand Down Expand Up @@ -1146,12 +1300,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(lx);
} else {
if (PyDict_SetItemString(road, "x", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
} else if (PyDict_SetItemString(road, "x", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
if (r->y && seg_len > 0) {
PyObject *ly = PyList_New(seg_len);
Expand All @@ -1177,12 +1329,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(ly);
} else {
if (PyDict_SetItemString(road, "y", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
} else if (PyDict_SetItemString(road, "y", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
if (r->z && seg_len > 0) {
PyObject *lz = PyList_New(seg_len);
Expand All @@ -1208,12 +1358,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(lz);
} else {
if (PyDict_SetItemString(road, "z", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
} else if (PyDict_SetItemString(road, "z", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}

/* Lane-specific fields */
Expand Down Expand Up @@ -1277,6 +1425,50 @@ static PyObject *my_get(PyObject *dict, Env *env) {
}
Py_DECREF(pf);

pf = PyFloat_FromDouble((double) r->length);
if (!pf) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
if (PyDict_SetItemString(road, "length", pf) < 0) {
Py_DECREF(pf);
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
Py_DECREF(pf);

if (is_road_lane(r->type) && r->cum_lengths != NULL && seg_len > 0) {
tmp = PyList_New(seg_len);
if (!tmp) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
for (int k = 0; k < seg_len; k++) {
PyObject *fv = PyFloat_FromDouble((double) r->cum_lengths[k]);
if (!fv) {
Py_DECREF(tmp);
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
PyList_SET_ITEM(tmp, k, fv);
}
if (PyDict_SetItemString(road, "cum_lengths", tmp) < 0) {
Py_DECREF(tmp);
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}
Py_DECREF(tmp);
} else if (PyDict_SetItemString(road, "cum_lengths", Py_None) < 0) {
Py_DECREF(road);
Py_DECREF(road_list);
return NULL;
}

PyList_SetItem(road_list, i, road);
}
if (PyDict_SetItemString(dict, "road_elements", road_list) < 0) {
Expand Down Expand Up @@ -1373,12 +1565,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(ls);
} else {
if (PyDict_SetItemString(traffic, "states", Py_None) < 0) {
Py_DECREF(traffic);
Py_DECREF(traffic_list);
return NULL;
}
} else if (PyDict_SetItemString(traffic, "states", Py_None) < 0) {
Py_DECREF(traffic);
Py_DECREF(traffic_list);
return NULL;
}

/* Stop line endpoints */
Expand Down Expand Up @@ -1455,12 +1645,10 @@ static PyObject *my_get(PyObject *dict, Env *env) {
return NULL;
}
Py_DECREF(ll);
} else {
if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) {
Py_DECREF(traffic);
Py_DECREF(traffic_list);
return NULL;
}
} else if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) {
Py_DECREF(traffic);
Py_DECREF(traffic_list);
return NULL;
}

PyList_SetItem(traffic_list, i, traffic);
Expand Down Expand Up @@ -1807,6 +1995,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->num_target_waypoints = MAX_TARGET_WAYPOINTS;
}
env->target_type = (int) unpack(kwargs, "target_type");
if (env->target_type == TARGET_DIJKSTRA) {
env->num_target_waypoints = DIJKSTRA_TARGET_SLOTS;
}
env->max_boundary_segment_observations = (int) unpack(kwargs, "max_boundary_segment_observations");
env->max_lane_segment_observations = (int) unpack(kwargs, "max_lane_segment_observations");
env->max_partner_observations = (int) unpack(kwargs, "max_partner_observations");
Expand Down
10 changes: 7 additions & 3 deletions pufferlib/ocean/drive/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,10 @@ struct Agent {
float goal_position_x; // alias = goal_positions_x[current_goal_idx]
float goal_position_y; // alias = goal_positions_y[current_goal_idx]
float goal_position_z; // alias = goal_positions_z[current_goal_idx]
int current_goal_idx; // index of next goal to reach (0..N-1)
int goal_lane_ids[MAX_TARGET_WAYPOINTS];
float goal_lane_s[MAX_TARGET_WAYPOINTS];
int num_active_goals;
int current_goal_idx; // index of next goal to reach (0..N-1)

int stopped; // 0/1 -> freeze if set
int removed; // 0/1 -> remove from sim if set
Expand Down Expand Up @@ -280,6 +283,8 @@ struct RoadMapElement {
int num_exits;
int *exit_lanes;
float speed_limit;
float length;
float *cum_lengths;
};

struct TrafficControlElement {
Expand All @@ -302,7 +307,6 @@ typedef struct {
struct LaneGraph {
int n_lanes;
int *lane_ids;
float *lane_lengths;
float *distances; // n_lanes * n_lanes row-major
};

Expand All @@ -328,6 +332,7 @@ void free_road_element(struct RoadMapElement *element) {
free(element->headings);
free(element->entry_lanes);
free(element->exit_lanes);
free(element->cum_lengths);
}

void free_traffic_element(struct TrafficControlElement *element) {
Expand All @@ -337,6 +342,5 @@ void free_traffic_element(struct TrafficControlElement *element) {

void free_lane_graph(struct LaneGraph *graph) {
free(graph->lane_ids);
free(graph->lane_lengths);
free(graph->distances);
}
Loading
Loading