diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index f67243a32d..cd1bce98cf 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -5,6 +5,156 @@ #define MY_GET #include "../env_binding.h" +static int clipped_debug_goal_count(Drive *env, Agent *agent) { + assert(env != NULL); + assert(agent != NULL); + int goal_count = get_agent_goal_count(env, agent); + if (goal_count < 0) { + return 0; + } + if (goal_count > MAX_TARGET_WAYPOINTS) { + return MAX_TARGET_WAYPOINTS; + } + return goal_count; +} + +static PyObject *build_goal_lane_ids(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *v = PyLong_FromLong(agent->goal_lane_ids[i]); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_lane_s(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *v = PyFloat_FromDouble((double) agent->goal_lane_s[i]); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_positions(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *pos = Py_BuildValue( + "(ddd)", + (double) agent->goal_positions_x[i], + (double) agent->goal_positions_y[i], + (double) agent->goal_positions_z[i]); + if (pos == NULL || PyList_SetItem(lst, i, pos) < 0) { + Py_XDECREF(pos); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_route_distances(Drive *env, Agent *agent, int goal_count) { + assert(env != NULL); + assert(agent != NULL); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + int from_lane_idx = agent->route_length > 0 ? agent->route[0] : agent->current_lane_idx; + float from_s = 0.0f; + if (from_lane_idx >= 0 && from_lane_idx < env->num_road_elements) { + from_s = lane_s_at_position(&env->road_elements[from_lane_idx], agent->sim_x, agent->sim_y); + } + for (int i = 0; i < goal_count; i++) { + float distance = route_distance_between_lane_positions( + env, + from_lane_idx, + from_s, + agent->goal_lane_ids[i], + agent->goal_lane_s[i]); + PyObject *v = PyFloat_FromDouble((double) distance); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + from_lane_idx = agent->goal_lane_ids[i]; + from_s = agent->goal_lane_s[i]; + } + return lst; +} + +static int set_agent_goal_debug_fields(PyObject *agent_dict, Drive *env, Agent *agent) { + assert(agent_dict != NULL); + assert(env != NULL); + int goal_count = clipped_debug_goal_count(env, agent); + PyObject *v = PyLong_FromLong(goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "num_active_goals", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = PyLong_FromLong(agent->current_goal_idx); + if (v == NULL || PyDict_SetItemString(agent_dict, "current_goal_idx", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_lane_ids(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_ids", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_lane_s(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_s", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_route_distances(env, agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_route_distances", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_positions(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_positions", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + return 0; +} + static int my_put(Env *env, PyObject *args, PyObject *kwargs) { PyObject *obs = PyDict_GetItemString(kwargs, "observations"); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { @@ -727,6 +877,12 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(pf); + if (set_agent_goal_debug_fields(agent, env, a)) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + /* Status flags */ tmp = PyLong_FromLong(a->stopped); if (!tmp) { @@ -922,12 +1078,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(route_list); - } else { - if (PyDict_SetItemString(agent, "route", Py_None) < 0) { - Py_DECREF(agent); - Py_DECREF(agents_list); - return NULL; - } + } else if (PyDict_SetItemString(agent, "route", Py_None) < 0) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; } PyList_SetItem(agents_list, i, agent); @@ -1146,12 +1300,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(lx); - } else { - if (PyDict_SetItemString(road, "x", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "x", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } if (r->y && seg_len > 0) { PyObject *ly = PyList_New(seg_len); @@ -1177,12 +1329,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ly); - } else { - if (PyDict_SetItemString(road, "y", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "y", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } if (r->z && seg_len > 0) { PyObject *lz = PyList_New(seg_len); @@ -1208,12 +1358,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(lz); - } else { - if (PyDict_SetItemString(road, "z", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "z", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } /* Lane-specific fields */ @@ -1277,6 +1425,50 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(pf); + pf = PyFloat_FromDouble((double) r->length); + if (!pf) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + if (PyDict_SetItemString(road, "length", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + Py_DECREF(pf); + + if (is_road_lane(r->type) && r->cum_lengths != NULL && seg_len > 0) { + tmp = PyList_New(seg_len); + if (!tmp) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + for (int k = 0; k < seg_len; k++) { + PyObject *fv = PyFloat_FromDouble((double) r->cum_lengths[k]); + if (!fv) { + Py_DECREF(tmp); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + PyList_SET_ITEM(tmp, k, fv); + } + if (PyDict_SetItemString(road, "cum_lengths", tmp) < 0) { + Py_DECREF(tmp); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + Py_DECREF(tmp); + } else if (PyDict_SetItemString(road, "cum_lengths", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + PyList_SetItem(road_list, i, road); } if (PyDict_SetItemString(dict, "road_elements", road_list) < 0) { @@ -1373,12 +1565,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ls); - } else { - if (PyDict_SetItemString(traffic, "states", Py_None) < 0) { - Py_DECREF(traffic); - Py_DECREF(traffic_list); - return NULL; - } + } else if (PyDict_SetItemString(traffic, "states", Py_None) < 0) { + Py_DECREF(traffic); + Py_DECREF(traffic_list); + return NULL; } /* Stop line endpoints */ @@ -1455,12 +1645,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ll); - } else { - if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) { - Py_DECREF(traffic); - Py_DECREF(traffic_list); - return NULL; - } + } else if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) { + Py_DECREF(traffic); + Py_DECREF(traffic_list); + return NULL; } PyList_SetItem(traffic_list, i, traffic); @@ -1807,6 +1995,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->num_target_waypoints = MAX_TARGET_WAYPOINTS; } env->target_type = (int) unpack(kwargs, "target_type"); + if (env->target_type == TARGET_DIJKSTRA) { + env->num_target_waypoints = DIJKSTRA_TARGET_SLOTS; + } env->max_boundary_segment_observations = (int) unpack(kwargs, "max_boundary_segment_observations"); env->max_lane_segment_observations = (int) unpack(kwargs, "max_lane_segment_observations"); env->max_partner_observations = (int) unpack(kwargs, "max_partner_observations"); diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h index 2a391fb744..077f847789 100644 --- a/pufferlib/ocean/drive/datatypes.h +++ b/pufferlib/ocean/drive/datatypes.h @@ -236,7 +236,10 @@ struct Agent { float goal_position_x; // alias = goal_positions_x[current_goal_idx] float goal_position_y; // alias = goal_positions_y[current_goal_idx] float goal_position_z; // alias = goal_positions_z[current_goal_idx] - int current_goal_idx; // index of next goal to reach (0..N-1) + int goal_lane_ids[MAX_TARGET_WAYPOINTS]; + float goal_lane_s[MAX_TARGET_WAYPOINTS]; + int num_active_goals; + int current_goal_idx; // index of next goal to reach (0..N-1) int stopped; // 0/1 -> freeze if set int removed; // 0/1 -> remove from sim if set @@ -280,6 +283,8 @@ struct RoadMapElement { int num_exits; int *exit_lanes; float speed_limit; + float length; + float *cum_lengths; }; struct TrafficControlElement { @@ -302,7 +307,6 @@ typedef struct { struct LaneGraph { int n_lanes; int *lane_ids; - float *lane_lengths; float *distances; // n_lanes * n_lanes row-major }; @@ -328,6 +332,7 @@ void free_road_element(struct RoadMapElement *element) { free(element->headings); free(element->entry_lanes); free(element->exit_lanes); + free(element->cum_lengths); } void free_traffic_element(struct TrafficControlElement *element) { @@ -337,6 +342,5 @@ void free_traffic_element(struct TrafficControlElement *element) { void free_lane_graph(struct LaneGraph *graph) { free(graph->lane_ids); - free(graph->lane_lengths); free(graph->distances); } diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 329e1bdf76..19f88b07ed 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -120,11 +120,12 @@ // TARGET_TYPE modes (controls what target info is in observations) #define TARGET_STATIC 0 #define TARGET_DYNAMIC 1 +#define TARGET_DIJKSTRA 2 // Observation feature counts #define EGO_FEATURES_CLASSIC 8 #define EGO_FEATURES_JERK 10 -#define ROAD_FEATURES 7 +#define ROAD_FEATURES 9 #define PARTNER_FEATURES 8 #define TRAFFIC_CONTROL_FEATURES 7 #define PADDED_OBSERVATION_VALUE -0.001f @@ -133,6 +134,10 @@ // GIGAFLOW specific #define MAX_ROUTE_LENGTH 64 +#define DIJKSTRA_TARGET_SLOTS 4 +#define DIJKSTRA_MIN_GOAL_DISTANCE 20.0f +#define DIJKSTRA_MAX_GOAL_DISTANCE 200.0f +#define DIJKSTRA_MAX_ROUTE_ATTEMPTS 64 // Traffic light generation #define TL_DEFAULT_RED_DURATION 2.0f #define TL_DEFAULT_YELLOW_DURATION 3.0f @@ -314,6 +319,7 @@ struct Drive { int num_traffic_elements; int num_objects; struct LaneGraph lane_graph; + int *lane_graph_index_by_lane; int static_agent_count; int *static_agent_indices; int expert_static_agent_count; @@ -495,6 +501,27 @@ static int traffic_control_in_scope(int type, int scope) { } } +static void build_lane_graph_index(Drive *env) { + free(env->lane_graph_index_by_lane); + env->lane_graph_index_by_lane = NULL; + if (env->num_road_elements <= 0 || env->lane_graph.n_lanes <= 0 || env->lane_graph.lane_ids == NULL) { + return; + } + env->lane_graph_index_by_lane = (int *) malloc(env->num_road_elements * sizeof(int)); + if (env->lane_graph_index_by_lane == NULL) { + return; + } + for (int i = 0; i < env->num_road_elements; i++) { + env->lane_graph_index_by_lane[i] = -1; + } + for (int i = 0; i < env->lane_graph.n_lanes; i++) { + int lane_idx = env->lane_graph.lane_ids[i]; + if (lane_idx >= 0 && lane_idx < env->num_road_elements) { + env->lane_graph_index_by_lane[lane_idx] = i; + } + } +} + static void reset_agent_metrics(Drive *env, int agent_idx) { Agent *agent = &env->agents[agent_idx]; for (int i = 0; i < NUM_METRICS; i++) { @@ -1134,12 +1161,23 @@ int load_map_binary(const char *filename, Drive *drive) { fclose(file); return -1; } + if (fread(&road->length, sizeof(float), 1, file) != 1) { + fclose(file); + return -1; + } + road->cum_lengths = (float *) malloc(slen * sizeof(float)); + if ((size_t) slen > 0 && fread(road->cum_lengths, sizeof(float), slen, file) != (size_t) slen) { + fclose(file); + return -1; + } } else { road->num_entries = 0; road->num_exits = 0; road->entry_lanes = NULL; road->exit_lanes = NULL; road->speed_limit = 0.0f; + road->length = 0.0f; + road->cum_lengths = NULL; } } @@ -1220,7 +1258,6 @@ int load_map_binary(const char *filename, Drive *drive) { } drive->lane_graph.n_lanes = n_lanes_graph; drive->lane_graph.lane_ids = NULL; - drive->lane_graph.lane_lengths = NULL; drive->lane_graph.distances = NULL; if (n_lanes_graph > 0) { drive->lane_graph.lane_ids = (int *) malloc(n_lanes_graph * sizeof(int)); @@ -1228,11 +1265,6 @@ int load_map_binary(const char *filename, Drive *drive) { fclose(file); return -1; } - drive->lane_graph.lane_lengths = (float *) malloc(n_lanes_graph * sizeof(float)); - if (fread(drive->lane_graph.lane_lengths, sizeof(float), n_lanes_graph, file) != (size_t) n_lanes_graph) { - fclose(file); - return -1; - } drive->lane_graph.distances = (float *) malloc(n_lanes_graph * n_lanes_graph * sizeof(float)); if (fread(drive->lane_graph.distances, sizeof(float), n_lanes_graph * n_lanes_graph, file) != (size_t) (n_lanes_graph * n_lanes_graph)) { @@ -1240,6 +1272,7 @@ int load_map_binary(const char *filename, Drive *drive) { return -1; } } + build_lane_graph_index(drive); // Metadata if (fread(drive->scenario_id, sizeof(char), 128, file) != 128) { @@ -1298,17 +1331,6 @@ int load_map_binary(const char *filename, Drive *drive) { // Road Utility Functions // ======================================== -// Compute the length of a lane -static float compute_lane_length(RoadMapElement *lane) { - float length = 0.0f; - for (int i = 1; i < lane->segment_length; i++) { - float dx = lane->x[i] - lane->x[i - 1]; - float dy = lane->y[i] - lane->y[i - 1]; - length += sqrtf(dx * dx + dy * dy); - } - return length; -} - // Compute the remaining distance on a lane from a given position to the end of the lane static float compute_remaining_lane_distance(RoadMapElement *lane, float pos_x, float pos_y) { // Find the closest segment to the position @@ -1343,23 +1365,9 @@ static float compute_remaining_lane_distance(RoadMapElement *lane, float pos_x, } } - // Compute remaining distance from closest point to end of lane - float remaining = 0.0f; - - // Partial distance in current segment (from t to end of segment) - float dx = lane->x[closest_seg + 1] - lane->x[closest_seg]; - float dy = lane->y[closest_seg + 1] - lane->y[closest_seg]; - float seg_len = sqrtf(dx * dx + dy * dy); - remaining += (1.0f - closest_t) * seg_len; - - // Full distance of remaining segments - for (int i = closest_seg + 1; i < lane->segment_length - 1; i++) { - dx = lane->x[i + 1] - lane->x[i]; - dy = lane->y[i + 1] - lane->y[i]; - remaining += sqrtf(dx * dx + dy * dy); - } - - return remaining; + float progress = lane->cum_lengths[closest_seg] + + closest_t * (lane->cum_lengths[closest_seg + 1] - lane->cum_lengths[closest_seg]); + return fmaxf(0.0f, lane->length - progress); } static float compute_lane_end_distance_sq(RoadMapElement *lane, float origin_x, float origin_y) { @@ -1373,6 +1381,118 @@ static float compute_lane_end_distance_sq(RoadMapElement *lane, float origin_x, return dx * dx + dy * dy; } +static int get_lane_graph_index(Drive *env, int lane_idx) { + if (lane_idx < 0 || lane_idx >= env->num_road_elements || env->lane_graph_index_by_lane == NULL) { + return -1; + } + return env->lane_graph_index_by_lane[lane_idx]; +} + +static int valid_route_distance(float distance) { + return isfinite(distance) && distance >= 0.0f && distance < 1e8f; +} + +static float lane_graph_distance(Drive *env, int from_lane_idx, int to_lane_idx) { + int from_idx = get_lane_graph_index(env, from_lane_idx); + int to_idx = get_lane_graph_index(env, to_lane_idx); + if (from_idx < 0 || to_idx < 0 || env->lane_graph.distances == NULL) { + return INFINITY; + } + return env->lane_graph.distances[from_idx * env->lane_graph.n_lanes + to_idx]; +} + +static float lane_s_at_position(RoadMapElement *lane, float pos_x, float pos_y) { + if (lane->cum_lengths == NULL || lane->segment_length < 2) { + return 0.0f; + } + + int closest_seg = 0; + float closest_t = 0.0f; + float min_dist_sq = 1e30f; + + for (int i = 0; i < lane->segment_length - 1; i++) { + float x0 = lane->x[i]; + float y0 = lane->y[i]; + float x1 = lane->x[i + 1]; + float y1 = lane->y[i + 1]; + float dx = x1 - x0; + float dy = y1 - y0; + float seg_len_sq = dx * dx + dy * dy; + float t = 0.0f; + if (seg_len_sq > 1e-6f) { + t = ((pos_x - x0) * dx + (pos_y - y0) * dy) / seg_len_sq; + t = fmaxf(0.0f, fminf(1.0f, t)); + } + float proj_x = x0 + t * dx; + float proj_y = y0 + t * dy; + float dist_sq = (pos_x - proj_x) * (pos_x - proj_x) + (pos_y - proj_y) * (pos_y - proj_y); + if (dist_sq < min_dist_sq) { + min_dist_sq = dist_sq; + closest_seg = i; + closest_t = t; + } + } + + float s = lane->cum_lengths[closest_seg] + + closest_t * (lane->cum_lengths[closest_seg + 1] - lane->cum_lengths[closest_seg]); + return clip(s, 0.0f, lane->length); +} + +static float lane_midpoint_s(RoadMapElement *lane, int geometry_idx) { + if (lane->cum_lengths == NULL || geometry_idx < 0 || geometry_idx >= lane->segment_length - 1) { + return 0.0f; + } + return 0.5f * (lane->cum_lengths[geometry_idx] + lane->cum_lengths[geometry_idx + 1]); +} + +static int lane_segment_at_s(RoadMapElement *lane, float lane_s) { + if (lane->segment_length < 2 || lane->cum_lengths == NULL) { + return 0; + } + float s = clip(lane_s, 0.0f, lane->length); + for (int i = 0; i < lane->segment_length - 1; i++) { + if (s <= lane->cum_lengths[i + 1]) { + return i; + } + } + return lane->segment_length - 2; +} + +static float route_distance_between_lane_positions( + Drive *env, + int from_lane_idx, + float from_s, + int to_lane_idx, + float to_s) { + if (from_lane_idx < 0 || from_lane_idx >= env->num_road_elements || to_lane_idx < 0 + || to_lane_idx >= env->num_road_elements) { + return INFINITY; + } + RoadMapElement *from_lane = &env->road_elements[from_lane_idx]; + RoadMapElement *to_lane = &env->road_elements[to_lane_idx]; + if (!is_drivable_road_lane(from_lane->type) || !is_drivable_road_lane(to_lane->type)) { + return INFINITY; + } + float from_clamped = clip(from_s, 0.0f, from_lane->length); + float to_clamped = clip(to_s, 0.0f, to_lane->length); + if (from_lane_idx == to_lane_idx) { + return (to_clamped >= from_clamped) ? (to_clamped - from_clamped) : INFINITY; + } + + float graph_distance = lane_graph_distance(env, from_lane_idx, to_lane_idx); + if (!valid_route_distance(graph_distance)) { + return INFINITY; + } + return fmaxf(0.0f, from_lane->length - from_clamped) + graph_distance + to_clamped; +} + +static int get_agent_goal_count(Drive *env, Agent *agent) { + if (env->target_type == TARGET_DIJKSTRA) { + return agent->num_active_goals; + } + return env->num_target_waypoints; +} + static float compute_progression(Agent *agent) { int num_wp = agent->path->num_waypoints; if (num_wp < 2) { @@ -1792,7 +1912,7 @@ static int generate_random_route( // Accumulate distance RoadMapElement *exit_lane = &env->road_elements[chosen_exit_idx]; - accumulated_distance += compute_lane_length(exit_lane); + accumulated_distance += exit_lane->length; if (chosen_exit_dist_sq > max_end_distance_sq) { max_end_distance_sq = chosen_exit_dist_sq; } @@ -1862,8 +1982,247 @@ static int compute_new_route(Drive *env, int agent_idx, int current_lane_idx) { return 1; // Success } +static void set_agent_goal_from_lane_s(Drive *env, Agent *agent, int goal_idx, int lane_idx, float lane_s) { + RoadMapElement *lane = &env->road_elements[lane_idx]; + int seg_idx = lane_segment_at_s(lane, lane_s); + float s0 = lane->cum_lengths[seg_idx]; + float s1 = lane->cum_lengths[seg_idx + 1]; + float denom = fmaxf(s1 - s0, 1e-6f); + float t = clip((lane_s - s0) / denom, 0.0f, 1.0f); + agent->goal_positions_x[goal_idx] = lane->x[seg_idx] + t * (lane->x[seg_idx + 1] - lane->x[seg_idx]); + agent->goal_positions_y[goal_idx] = lane->y[seg_idx] + t * (lane->y[seg_idx + 1] - lane->y[seg_idx]); + agent->goal_positions_z[goal_idx] = lane->z[seg_idx] + t * (lane->z[seg_idx + 1] - lane->z[seg_idx]); + agent->goal_lane_ids[goal_idx] = lane_idx; + agent->goal_lane_s[goal_idx] = clip(lane_s, 0.0f, lane->length); +} + +static int build_dijkstra_route_to_goal( + Drive *env, + int start_lane_idx, + int goal_lane_idx, + int *route, + int max_route_length) { + if (start_lane_idx < 0 || goal_lane_idx < 0 || max_route_length <= 0) { + return 0; + } + int current_lane_idx = start_lane_idx; + int route_length = 0; + route[route_length++] = current_lane_idx; + + for (int step = 0; step < DIJKSTRA_MAX_ROUTE_ATTEMPTS && route_length < max_route_length; step++) { + if (current_lane_idx == goal_lane_idx) { + return route_length; + } + RoadMapElement *current_lane = &env->road_elements[current_lane_idx]; + float current_dist = lane_graph_distance(env, current_lane_idx, goal_lane_idx); + if (!valid_route_distance(current_dist)) { + return 0; + } + + int best_exit = -1; + float best_dist = current_dist; + for (int e = 0; e < current_lane->num_exits; e++) { + int exit_lane_idx = current_lane->exit_lanes[e]; + if (exit_lane_idx < 0 || exit_lane_idx >= env->num_road_elements) { + continue; + } + if (!is_drivable_road_lane(env->road_elements[exit_lane_idx].type)) { + continue; + } + float exit_dist = lane_graph_distance(env, exit_lane_idx, goal_lane_idx); + if (valid_route_distance(exit_dist) && exit_dist < best_dist - 1e-3f) { + best_dist = exit_dist; + best_exit = exit_lane_idx; + } + } + if (best_exit == -1) { + return 0; + } + route[route_length++] = best_exit; + current_lane_idx = best_exit; + } + return current_lane_idx == goal_lane_idx ? route_length : 0; +} + +static int set_agent_route_from_buffer(Drive *env, int agent_idx, int *route, int route_length) { + Agent *agent = &env->agents[agent_idx]; + if (route_length <= 0) { + return 0; + } + free(agent->route); + agent->route = (int *) malloc(route_length * sizeof(int)); + if (agent->route == NULL) { + agent->route_length = 0; + return 0; + } + agent->route_length = route_length; + for (int i = 0; i < route_length; i++) { + agent->route[i] = route[i]; + } + agent->current_route_index = 0; + build_path(env, agent_idx); + agent->closest_path_idx_wp = 0; + agent->closest_path_idx_wp = get_closest_waypoint_index_on_path(env, agent_idx); + agent->path_progression = compute_progression(agent); + return 1; +} + +static int rebuild_dijkstra_route_to_current_goal(Drive *env, int agent_idx) { + Agent *agent = &env->agents[agent_idx]; + int goal_count = get_agent_goal_count(env, agent); + if (agent->current_goal_idx < 0 || agent->current_goal_idx >= goal_count) { + return 0; + } + int start_lane_idx = agent->current_lane_idx; + if (start_lane_idx < 0) { + start_lane_idx = agent->previous_lane_idx; + } + int goal_lane_idx = agent->goal_lane_ids[agent->current_goal_idx]; + int temp_route[MAX_ROUTE_LENGTH]; + int route_length = build_dijkstra_route_to_goal(env, start_lane_idx, goal_lane_idx, temp_route, MAX_ROUTE_LENGTH); + return set_agent_route_from_buffer(env, agent_idx, temp_route, route_length); +} + +static int compute_dijkstra_goals(Drive *env, int agent_idx) { + assert(env != NULL); + assert(agent_idx >= 0 && agent_idx < env->num_total_agents); + + Agent *agent = &env->agents[agent_idx]; + int active_goal_count = 1 + (rand() % DIJKSTRA_TARGET_SLOTS); + for (int i = 0; i < MAX_TARGET_WAYPOINTS; i++) { + agent->goal_positions_x[i] = 0.0f; + agent->goal_positions_y[i] = 0.0f; + agent->goal_positions_z[i] = 0.0f; + agent->goal_lane_ids[i] = -1; + agent->goal_lane_s[i] = 0.0f; + } + + int start_lane_idx = agent->current_lane_idx; + if (start_lane_idx < 0) { + return 0; + } + RoadMapElement *start_lane = &env->road_elements[start_lane_idx]; + float start_s = lane_s_at_position(start_lane, agent->sim_x, agent->sim_y); + + int first_route[MAX_ROUTE_LENGTH]; + int first_route_length = 0; + int first_goal_lane_idx = -1; + float first_goal_lane_s = 0.0f; + + for (int attempt = 0; attempt < 10; attempt++) { + int idx = rand() % env->num_road_elements; + if (!is_drivable_road_lane(env->road_elements[idx].type)) { + continue; + } + float graph_dist = lane_graph_distance(env, start_lane_idx, idx); + if (!valid_route_distance(graph_dist)) { + continue; + } + float lane_len = env->road_elements[idx].length; + float s = random_uniform(0.0f, lane_len); + float total_dist = graph_dist + s - start_s; + if (total_dist < DIJKSTRA_MIN_GOAL_DISTANCE) { + continue; + } + int temp_route[MAX_ROUTE_LENGTH]; + int route_length = build_dijkstra_route_to_goal(env, start_lane_idx, idx, temp_route, MAX_ROUTE_LENGTH); + if (route_length > 0) { + first_goal_lane_idx = idx; + first_route_length = route_length; + first_goal_lane_s = s; + for (int i = 0; i < route_length; i++) { + first_route[i] = temp_route[i]; + } + break; + } + } + + if (first_goal_lane_idx == -1) { + first_goal_lane_idx = start_lane_idx; + first_route_length = 1; + first_route[0] = start_lane_idx; + float fallback_s = start_s + DIJKSTRA_MIN_GOAL_DISTANCE; + float lane_len = env->road_elements[start_lane_idx].length; + if (fallback_s > lane_len) { + fallback_s = lane_len; + } + first_goal_lane_s = fallback_s; + } + + set_agent_goal_from_lane_s(env, agent, 0, first_goal_lane_idx, first_goal_lane_s); + + int current_lane_idx = first_goal_lane_idx; + float current_s = first_goal_lane_s; + int sampled_count = 1; + + for (int goal_idx = 1; goal_idx < active_goal_count; goal_idx++) { + float remaining_dist = random_uniform(DIJKSTRA_MIN_GOAL_DISTANCE + 5.0f, DIJKSTRA_MAX_GOAL_DISTANCE - 5.0f); + int next_lane_idx = current_lane_idx; + float next_s = current_s; + + int step = 0; + int max_steps = 100; + while (remaining_dist > 0.0f && step < max_steps) { + step++; + RoadMapElement *lane = &env->road_elements[next_lane_idx]; + float segment_rem = lane->length - next_s; + if (remaining_dist <= segment_rem) { + next_s += remaining_dist; + remaining_dist = 0.0f; + break; + } else { + remaining_dist -= segment_rem; + if (lane->num_exits > 0) { + int valid_exits[MAX_ROUTE_LENGTH]; + int valid_count = 0; + for (int e = 0; e < lane->num_exits && e < MAX_ROUTE_LENGTH; e++) { + int exit_lane_idx = lane->exit_lanes[e]; + if (exit_lane_idx >= 0 && exit_lane_idx < env->num_road_elements) { + if (is_drivable_road_lane(env->road_elements[exit_lane_idx].type)) { + valid_exits[valid_count++] = exit_lane_idx; + } + } + } + if (valid_count > 0) { + next_lane_idx = valid_exits[rand() % valid_count]; + next_s = 0.0f; + } else { + next_s = lane->length; + remaining_dist = 0.0f; + break; + } + } else { + next_s = lane->length; + remaining_dist = 0.0f; + break; + } + } + } + + set_agent_goal_from_lane_s(env, agent, goal_idx, next_lane_idx, next_s); + current_lane_idx = next_lane_idx; + current_s = next_s; + sampled_count++; + } + + agent->num_active_goals = sampled_count; + agent->current_goal_idx = 0; + agent->goal_position_x = agent->goal_positions_x[0]; + agent->goal_position_y = agent->goal_positions_y[0]; + agent->goal_position_z = agent->goal_positions_z[0]; + return set_agent_route_from_buffer(env, agent_idx, first_route, first_route_length); +} + static void compute_goals(Drive *env, int agent_idx) { Agent *agent = &env->agents[agent_idx]; + if (env->target_type == TARGET_DIJKSTRA) { + if (!compute_dijkstra_goals(env, agent_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute dijkstra goals for agent %d\n", agent_idx); + agent->removed = 1; + } + return; + } + struct Path *path = agent->path; // Validate path exists @@ -1930,6 +2289,12 @@ static void compute_goals(Drive *env, int agent_idx) { agent->goal_positions_x[i] = path->waypoints[wp_idx].x; agent->goal_positions_y[i] = path->waypoints[wp_idx].y; agent->goal_positions_z[i] = path->waypoints[wp_idx].z; + int goal_lane_idx = path->waypoints[wp_idx].lane_idx; + agent->goal_lane_ids[i] = goal_lane_idx; + agent->goal_lane_s[i] = lane_s_at_position( + &env->road_elements[goal_lane_idx], + agent->goal_positions_x[i], + agent->goal_positions_y[i]); } // Reset goal index and update alias @@ -3009,14 +3374,22 @@ static int spawn_agent(Drive *env, int agent_idx, int num_agents) { agent->yaw_rate = 0.0f; update_agent_speed(agent); - // Compute initial route - if (!compute_new_route(env, agent_idx, start_lane_idx)) { - printf("[GIGAFLOW WARNING] -> Failed to compute a new route for agent %d\n", agent_idx); - return 0; // Failed to compute new goal - } + if (env->target_type == TARGET_DIJKSTRA) { + agent->current_lane_idx = start_lane_idx; + if (!compute_dijkstra_goals(env, agent_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute dijkstra goals for agent %d\n", agent_idx); + return 0; + } + } else { + // Compute initial route + if (!compute_new_route(env, agent_idx, start_lane_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute a new route for agent %d\n", agent_idx); + return 0; // Failed to compute new goal + } - // Compute initial goal - compute_goals(env, agent_idx); + // Compute initial goal + compute_goals(env, agent_idx); + } return 1; // Success } @@ -3361,6 +3734,9 @@ void remove_bad_trajectories(Drive *env) { void init(Drive *env) { env->human_agent_idx = 0; env->timestep = 0; + if (env->target_type == TARGET_DIJKSTRA) { + env->num_target_waypoints = DIJKSTRA_TARGET_SLOTS; + } load_map_binary(env->map_name, env); env->road_dropout_enabled = (env->obs_lane_segment_count < env->max_lane_segment_observations) || (env->obs_boundary_segment_count < env->max_boundary_segment_observations); @@ -3491,7 +3867,7 @@ static int compute_observation_size(Drive *env) { if (env->reward_conditioning) { max_obs += NUM_REWARD_COEFS; } - if (env->target_type == TARGET_STATIC) { + if (env->target_type == TARGET_STATIC || env->target_type == TARGET_DIJKSTRA) { max_obs += num_target_waypoints * STATIC_TARGET_FEATURES; } else if (env->target_type == TARGET_DYNAMIC) { max_obs += num_target_waypoints * DYNAMIC_TARGET_FEATURES; @@ -3976,9 +4352,8 @@ static void compute_metrics(Drive *env, int agent_idx) { = compute_euclidean_distance(agent->sim_x, agent->sim_y, agent->goal_position_x, agent->goal_position_y); float goal_z_dist = fabsf(agent->sim_z - agent->goal_position_z); - // Goal reaching — guard against incrementing past num_target_waypoints - if (agent->current_goal_idx < env->num_target_waypoints - && distance_to_goal < agent->reward_coefs[REWARD_COEF_GOAL_RADIUS] && goal_z_dist < Z_BUFFER) { + // Goal reaching + if (distance_to_goal < agent->reward_coefs[REWARD_COEF_GOAL_RADIUS] && goal_z_dist < Z_BUFFER) { agent->metrics_array[REACHED_GOAL_IDX] = 1.0f; agent->current_goal_idx++; } @@ -4025,8 +4400,9 @@ static void compute_rewards(Drive *env, int i) { // Goal reward if (agent->metrics_array[REACHED_GOAL_IDX] > 0.0f) { float weight = 1.0f; + int goal_count = get_agent_goal_count(env, agent); if (env->simulation_mode == SIMULATION_GIGAFLOW) { - if (agent->current_goal_idx == env->num_target_waypoints + if (agent->current_goal_idx == goal_count && agent->sim_speed > agent->reward_coefs[REWARD_COEF_GOAL_SPEED]) { weight = 0.0f; } @@ -4228,9 +4604,10 @@ static void compute_observations(Drive *env) { } } // Target observations (static or dynamic) - if (env->target_type == TARGET_STATIC) { + if (env->target_type == TARGET_STATIC || env->target_type == TARGET_DIJKSTRA) { + int goal_count = get_agent_goal_count(env, ego_entity); for (int wp = 0; wp < env->num_target_waypoints; wp++) { - if (wp < ego_entity->current_goal_idx) { + if (wp < ego_entity->current_goal_idx || wp >= goal_count) { // Already reached - zeroed obs[obs_idx++] = 0.0f; obs[obs_idx++] = 0.0f; @@ -4401,6 +4778,24 @@ static void compute_observations(Drive *env) { float *boundaries_dest = env->road_dropout_enabled ? boundaries_buffer : &obs[boundary_obs_idx]; int lanes_collected = 0; int boundaries_collected = 0; + int goal_count = get_agent_goal_count(env, ego_entity); + int goal_lane_idx = -1; + float goal_lane_s = 0.0f; + float ego_goal_route_distance = INFINITY; + if (ego_entity->current_goal_idx >= 0 && ego_entity->current_goal_idx < goal_count) { + goal_lane_idx = ego_entity->goal_lane_ids[ego_entity->current_goal_idx]; + goal_lane_s = ego_entity->goal_lane_s[ego_entity->current_goal_idx]; + } + if (goal_lane_idx >= 0 && ego_entity->current_lane_idx >= 0) { + RoadMapElement *ego_lane = &env->road_elements[ego_entity->current_lane_idx]; + float ego_lane_s = lane_s_at_position(ego_lane, ego_entity->sim_x, ego_entity->sim_y); + ego_goal_route_distance = route_distance_between_lane_positions( + env, + ego_entity->current_lane_idx, + ego_lane_s, + goal_lane_idx, + goal_lane_s); + } for (int k = 0; k < list_size; k++) { if (lanes_collected >= env->max_lane_segment_observations @@ -4503,6 +4898,28 @@ static void compute_observations(Drive *env) { target[base + 5] = cos_angle; // Road segment orientation (sine) target[base + 6] = sin_angle; + if (is_lane) { + float segment_s = lane_midpoint_s(element, geometry_idx); + float segment_goal_distance + = route_distance_between_lane_positions(env, entity_idx, segment_s, goal_lane_idx, goal_lane_s); + if (valid_route_distance(segment_goal_distance)) { + target[base + 7] = clip(segment_goal_distance / env->max_goal_position, -1.0f, 1.0f); + if (valid_route_distance(ego_goal_route_distance)) { + target[base + 8] = clip( + (segment_goal_distance - ego_goal_route_distance) / env->max_goal_position, + -1.0f, + 1.0f); + } else { + target[base + 8] = 0.0f; + } + } else { + target[base + 7] = 1.0f; + target[base + 8] = 0.0f; + } + } else { + target[base + 7] = 0.0f; + target[base + 8] = 0.0f; + } } if (env->road_dropout_enabled) { @@ -5095,8 +5512,9 @@ void c_step(Drive *env) { int agent_idx = env->active_agent_indices[i]; Agent *agent = &env->agents[agent_idx]; if (agent->metrics_array[REACHED_GOAL_IDX] > 0.0f) { - if (agent->current_goal_idx == env->num_target_waypoints) { - // Last goal reached + int goal_count = get_agent_goal_count(env, agent); + if (agent->current_goal_idx == goal_count) { + // Last goal reached - generate new set of goals env->logs[i].num_goals_reached += 1; if (env->simulation_mode == SIMULATION_REPLAY) { // Replay mode: leave current_goal_idx saturated so the @@ -5110,6 +5528,9 @@ void c_step(Drive *env) { agent->goal_position_x = agent->goal_positions_x[agent->current_goal_idx]; agent->goal_position_y = agent->goal_positions_y[agent->current_goal_idx]; agent->goal_position_z = agent->goal_positions_z[agent->current_goal_idx]; + if (env->target_type == TARGET_DIJKSTRA) { + rebuild_dijkstra_route_to_current_goal(env, agent_idx); + } } } } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index d49563ebf6..339262039c 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -141,8 +141,12 @@ def __init__( self.target_type = binding.TARGET_STATIC elif target_type == "dynamic": self.target_type = binding.TARGET_DYNAMIC + elif target_type == "dijkstra": + self.target_type = binding.TARGET_DIJKSTRA + self.num_target_waypoints = 4 + num_target_waypoints = 4 else: - raise ValueError(f"target_type must be 'static' or 'dynamic'. Got: {target_type}") + raise ValueError(f"target_type must be 'static', 'dynamic', or 'dijkstra'. Got: {target_type}") self.collision_behavior = collision_behavior self.offroad_behavior = offroad_behavior self.traffic_light_behavior = traffic_light_behavior @@ -213,7 +217,7 @@ def __init__( self.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 # Target features based on target_type - if target_type == "static": + if target_type == "static" or target_type == "dijkstra": self.target_features = binding.STATIC_TARGET_FEATURES else: self.target_features = binding.DYNAMIC_TARGET_FEATURES diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 4822895dfb..f54bef3bf1 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1268,6 +1268,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "NUM_REWARD_COEFS", NUM_REWARD_COEFS); PyModule_AddIntConstant(m, "TARGET_STATIC", TARGET_STATIC); PyModule_AddIntConstant(m, "TARGET_DYNAMIC", TARGET_DYNAMIC); + PyModule_AddIntConstant(m, "TARGET_DIJKSTRA", TARGET_DIJKSTRA); PyObject_SetAttrString(m, "MULTI_LANE_FULL_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_FULL_SCORE_TIME)); PyObject_SetAttrString(m, "MULTI_LANE_HALF_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_HALF_SCORE_TIME)); diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index 25d18a513e..d2c9cc5880 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -94,6 +94,8 @@ static int handler(void *config, const char *section, const char *name, const ch env_config->target_type = 0; // TARGET_STATIC } else if (strcmp(value, "\"dynamic\"") == 0 || strcmp(value, "dynamic") == 0) { env_config->target_type = 1; // TARGET_DYNAMIC + } else if (strcmp(value, "\"dijkstra\"") == 0 || strcmp(value, "dijkstra") == 0) { + env_config->target_type = 2; // TARGET_DIJKSTRA } else { printf("Warning: Unknown target_type value '%s', defaulting to static\n", value); env_config->target_type = 0; diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin index 7cdff08dd0..242e0a252a 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin index 5eed725922..0c1f829abb 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin index e5db89ffdd..30a47d4fcb 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin index 38d42572f1..0dcbabfcab 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin index 4836a5dc37..e7d344333a 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin index bfe8c024f5..506ed84d90 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin index 1732c594fd..ea45f1cd08 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin index bd913dee5a..a7da5036e8 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin differ diff --git a/pufferlib/viz.py b/pufferlib/viz.py index a15e3fbc8e..5917309858 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -68,6 +68,7 @@ MULTI_LANE_FULL_SCORE_TIME = binding.MULTI_LANE_FULL_SCORE_TIME MULTI_LANE_HALF_SCORE_TIME = binding.MULTI_LANE_HALF_SCORE_TIME +PADDED_OBSERVATION_VALUE = -0.001 METRIC_LABELS = [ "collision", @@ -154,6 +155,10 @@ def _scale_ratio(numerator, denominator, default=1.0): return default if denominator == 0 else float(numerator) / float(denominator) +def _is_empty_obs_row(row): + return np.all(row == 0) or np.all(row == PADDED_OBSERVATION_VALUE) + + def _obs_scales( env_cfg=None, max_goal_position=100.0, @@ -181,6 +186,38 @@ def _obs_scales( } +def _target_type_name(target_type): + if target_type == binding.TARGET_DYNAMIC: + return "dynamic" + if target_type == binding.TARGET_DIJKSTRA: + return "dijkstra" + if isinstance(target_type, str): + target_type = target_type.strip('"') + if target_type == str(binding.TARGET_STATIC): + return "static" + if target_type == str(binding.TARGET_DYNAMIC): + return "dynamic" + if target_type == str(binding.TARGET_DIJKSTRA): + return "dijkstra" + return target_type + return "static" + + +def _dynamics_model_name(dynamics_model): + if dynamics_model == 1: + return "jerk" + if isinstance(dynamics_model, str): + dynamics_model = dynamics_model.strip('"') + if dynamics_model == "0": + return "classic" + return "jerk" if dynamics_model == "1" else dynamics_model + return "classic" + + +def _target_waypoint_count(target_type, num_target_waypoints): + return 4 if _target_type_name(target_type) == "dijkstra" else num_target_waypoints + + def _init_fig_ax(config: VizConfig, reuse_key: str = None, with_metrics: bool = False): cache_key = f"{reuse_key}_{'metrics' if with_metrics else 'single'}" if reuse_key else None @@ -852,11 +889,11 @@ def unpack_obs( dynamics_model: int = 0, target_type: str = "static", reward_conditioning: bool = False, - num_target_waypoints: int = 5, + num_target_waypoints: int = 3, max_partners: int = 16, - max_lane_segments: int = 16, - max_boundary_segments: int = 16, - max_traffic_control_observations: int = 16, + max_lane_segments: int = 32, + max_boundary_segments: int = 32, + max_traffic_control_observations: int = 4, lane_segment_dropout: float = 0.0, boundary_segment_dropout: float = 0.0, agent_idx: int = 0, @@ -873,6 +910,9 @@ def unpack_obs( if obs_flat.ndim == 1: obs_flat = obs_flat[None, :] + dynamics_model = _dynamics_model_name(dynamics_model) + target_type = _target_type_name(target_type) + num_target_waypoints = _target_waypoint_count(target_type, num_target_waypoints) ego_dim = binding.EGO_FEATURES_JERK if dynamics_model == "jerk" else binding.EGO_FEATURES_CLASSIC # Partner obs @@ -885,7 +925,7 @@ def unpack_obs( boundary_segment_count = compute_effective_road_obs_count(max_boundary_segments, boundary_segment_dropout) # Target obs - target_features = binding.STATIC_TARGET_FEATURES if target_type == "static" else binding.DYNAMIC_TARGET_FEATURES + target_features = binding.DYNAMIC_TARGET_FEATURES if target_type == "dynamic" else binding.STATIC_TARGET_FEATURES target_dim = num_target_waypoints * target_features # Extract ego state @@ -902,20 +942,29 @@ def unpack_obs( # Extract partners partners_start = target_end partners_end = partners_start + max_partners * partner_feature_size - partners_obs = obs_flat[:, partners_start:partners_end] - partners_obs = partners_obs.reshape(-1, max_partners, partner_feature_size) + if max_partners > 0: + partners_obs = obs_flat[:, partners_start:partners_end] + partners_obs = partners_obs.reshape(-1, max_partners, partner_feature_size) + else: + partners_obs = np.zeros((obs_flat.shape[0], 0, partner_feature_size)) # Extract lane elements lane_start = partners_end lane_end = lane_start + lane_segment_count * road_feature_size - lane_obs = obs_flat[:, lane_start:lane_end] - lane_obs = lane_obs.reshape(-1, lane_segment_count, road_feature_size) + if lane_segment_count > 0: + lane_obs = obs_flat[:, lane_start:lane_end] + lane_obs = lane_obs.reshape(-1, lane_segment_count, road_feature_size) + else: + lane_obs = np.zeros((obs_flat.shape[0], 0, road_feature_size)) # Extract boundary elements boundary_start = lane_end boundary_end = boundary_start + boundary_segment_count * road_feature_size - boundary_obs = obs_flat[:, boundary_start:boundary_end] - boundary_obs = boundary_obs.reshape(-1, boundary_segment_count, road_feature_size) + if boundary_segment_count > 0: + boundary_obs = obs_flat[:, boundary_start:boundary_end] + boundary_obs = boundary_obs.reshape(-1, boundary_segment_count, road_feature_size) + else: + boundary_obs = np.zeros((obs_flat.shape[0], 0, road_feature_size)) # Extract traffic controls traffic_start = boundary_end @@ -943,7 +992,7 @@ def plot_observation( dynamics_model="classic", target_type="static", reward_conditioning=False, - num_target_waypoints=10, + num_target_waypoints=3, max_partners=16, max_lane_segments=32, max_boundary_segments=32, @@ -966,6 +1015,9 @@ def plot_observation( target_type: 0 for goal only, 1 for waypoints only, 2 for both """ fig, ax = plt.subplots(figsize=(20, 20)) + dynamics_model = _dynamics_model_name(dynamics_model) + target_type = _target_type_name(target_type) + num_target_waypoints = _target_waypoint_count(target_type, num_target_waypoints) ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs = unpack_obs( obs, @@ -989,12 +1041,14 @@ def plot_observation( max_road_segment_length=max_road_segment_length, max_road_segment_width=max_road_segment_width, ) - target_position_scale = scales["goal_to_position"] if target_type == "static" else 1.0 + target_position_scale = scales["goal_to_position"] if target_type != "dynamic" else 1.0 if dynamics_model == "jerk": - ego_speed, ego_width, ego_length, steering_angle, a_long, a_lat, lcenter, lalign, speed_limit = ego_state + ego_speed, ego_width, ego_length, steering_angle, a_long, a_lat, lcenter, lalign, speed_limit, stopped = ( + ego_state + ) else: - ego_speed, ego_width, ego_length, lcenter, lalign, speed_limit = ego_state + ego_speed, ego_width, ego_length, steering_angle, lcenter, lalign, speed_limit, stopped = ego_state ego_width *= scales["veh_width_to_position"] ego_length *= scales["veh_len_to_position"] @@ -1005,26 +1059,13 @@ def plot_observation( (-ego_length / 2, -ego_width / 2), ego_length, ego_width, - facecolor="#0055FF", - edgecolor="#FFD700", - linewidth=4, - alpha=0.9, + facecolor="blue", + edgecolor="black", + linewidth=2, + alpha=0.7, zorder=10, ) ) - # SDC label above the vehicle - ax.text( - 0, - ego_width / 2 + 0.03, - "SDC", - ha="center", - va="bottom", - fontsize=11, - fontweight="bold", - color="#FFD700", - bbox=dict(boxstyle="round,pad=0.2", facecolor="#0055FF", edgecolor="#FFD700", linewidth=1.5), - zorder=11, - ) # Draw target waypoints for i in range(target_obs.shape[0]): @@ -1032,7 +1073,7 @@ def plot_observation( continue wp_x = target_obs[i][0] * target_position_scale wp_y = target_obs[i][1] * target_position_scale - if target_type == "static": + if target_type != "dynamic": color = "red" if i == 0 else "orange" marker = "*" if i == 0 else "o" s = 200 if i == 0 else 80 @@ -1043,10 +1084,10 @@ def plot_observation( ax.scatter(wp_x, wp_y, color=color, marker=marker, s=s, zorder=15) # Add dynamics info text for JERK model - ego_info = f"Speed: {ego_speed:.2f}\nLane Centering: {lcenter:.2f}\nLane Align: {lalign:.2f}\nSpeed Limit: {speed_limit:.2f}" + ego_info = f"Speed: {ego_speed:.2f}\nSteering: {steering_angle:.3f}\nLane Centering: {lcenter:.2f}\nLane Align: {lalign:.2f}\nSpeed Limit: {speed_limit:.2f}\nStopped: {stopped:.2f}" if dynamics_model == "jerk": - ego_info += f"\nSteering: {steering_angle:.3f}\na_long: {a_long:.2f}\na_lat: {a_lat:.2f}" + ego_info += f"\na_long: {a_long:.2f}\na_lat: {a_lat:.2f}" ax.text( 0.02, @@ -1060,7 +1101,7 @@ def plot_observation( # Partner agents for i in range(partners_obs.shape[0]): - if np.all(partners_obs[i] == 0): + if _is_empty_obs_row(partners_obs[i]): continue rel_x, rel_y = partners_obs[i][0], partners_obs[i][1] length = partners_obs[i][3] * scales["veh_len_to_position"] @@ -1086,7 +1127,7 @@ def plot_observation( rw2p = scales["road_width_to_position"] count_lane = 0 for i in range(lane_obs.shape[0]): - if np.all(lane_obs[i] == 0): + if _is_empty_obs_row(lane_obs[i]): continue count_lane += 1 rel_x, rel_y = lane_obs[i][0], lane_obs[i][1] @@ -1095,16 +1136,15 @@ def plot_observation( color = "lightgrey" ax.scatter(rel_x, rel_y, color=color, s=10, zorder=1) ax.plot( - [rel_x + dir_cos * length / 2, rel_x - dir_cos * length / 2], - [rel_y + dir_sin * length / 2, rel_y - dir_sin * length / 2], + [rel_x + dir_cos * length, rel_x - dir_cos * length], + [rel_y + dir_sin * length, rel_y - dir_sin * length], color=color, linewidth=1, zorder=1, ) - count_boundary = 0 for i in range(boundary_obs.shape[0]): - if np.all(boundary_obs[i] == 0): + if _is_empty_obs_row(boundary_obs[i]): continue count_boundary += 1 rel_x, rel_y = boundary_obs[i][0], boundary_obs[i][1] @@ -1113,8 +1153,8 @@ def plot_observation( color = "black" ax.scatter(rel_x, rel_y, color=color, s=10, zorder=1) ax.plot( - [rel_x + dir_cos * length / 2, rel_x - dir_cos * length / 2], - [rel_y + dir_sin * length / 2, rel_y - dir_sin * length / 2], + [rel_x + dir_cos * length, rel_x - dir_cos * length], + [rel_y + dir_sin * length, rel_y - dir_sin * length], color=color, linewidth=1, zorder=1, @@ -1132,10 +1172,12 @@ def plot_observation( # Traffic controls for i in range(traffic_controls_obs.shape[0]): - if np.all(traffic_controls_obs[i] == 0): + if _is_empty_obs_row(traffic_controls_obs[i]): continue rel_x1, rel_y1, rel_x2, rel_y2, _, control_type, state = traffic_controls_obs[i] control_type = int(control_type) + if _traffic_control_kind(control_type) is None: + continue if control_type == binding.TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT: ax.plot( [rel_x1, rel_x2], @@ -1276,12 +1318,15 @@ def fill_trajectories(scenario, timestep): def extract_obs_frame(obs, scenario, args, timestep, obs_index=0, agent_idx=0, head_north=False): + dynamics_model = _dynamics_model_name(args["env"]["dynamics_model"]) + target_type = _target_type_name(args["env"]["target_type"]) + num_target_waypoints = _target_waypoint_count(target_type, args["env"]["num_target_waypoints"]) ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs = unpack_obs( obs, - dynamics_model=args["env"]["dynamics_model"], - target_type=args["env"]["target_type"], + dynamics_model=dynamics_model, + target_type=target_type, reward_conditioning=args["env"]["reward_conditioning"], - num_target_waypoints=args["env"]["num_target_waypoints"], + num_target_waypoints=num_target_waypoints, max_partners=args["env"]["max_partner_observations"], max_lane_segments=args["env"]["max_lane_segment_observations"], max_boundary_segments=args["env"]["max_boundary_segment_observations"], @@ -1291,7 +1336,8 @@ def extract_obs_frame(obs, scenario, args, timestep, obs_index=0, agent_idx=0, h agent_idx=obs_index, ) scales = _obs_scales(args.get("env")) - target_position_scale = scales["goal_to_position"] if args["env"]["target_type"] == "static" else 1.0 + max_position = scales["max_position"] + target_position_scale = scales["goal_to_position"] if target_type != "dynamic" else 1.0 # --- Rotation Helper --- def _rot(x, y): @@ -1299,11 +1345,11 @@ def _rot(x, y): return (-y, x) if head_north else (x, y) # --- Parse Ego --- - if args["env"]["dynamics_model"] == "jerk": + if dynamics_model == "jerk": ego_speed, ego_width, ego_length, steering_angle, a_long, a_lat = ego_state[:6] else: - ego_speed, ego_width, ego_length = ego_state[:3] - steering_angle, a_long, a_lat = 0.0, 0.0, 0.0 + ego_speed, ego_width, ego_length, steering_angle = ego_state[:4] + a_long, a_lat = 0.0, 0.0 ego_width *= scales["veh_width_to_position"] ego_length *= scales["veh_len_to_position"] @@ -1324,7 +1370,7 @@ def _rot(x, y): def parse_roads(roads): res = [] for r in roads: - if np.all(r == 0): + if _is_empty_obs_row(r): continue x, y = r[0], r[1] length, width = r[3] * rl2p, r[4] * rw2p @@ -1335,22 +1381,21 @@ def parse_roads(roads): else: x_rot, y_rot = x, y cos_rot, sin_rot = cos_a, sin_a - res.append( - [ - round(float(x_rot), 4), - round(float(y_rot), 4), - round(float(length), 4), - round(float(width), 4), - round(float(cos_rot), 4), - round(float(sin_rot), 4), - ] - ) + row = [ + round(float(x_rot), 4), + round(float(y_rot), 4), + round(float(length), 4), + round(float(width), 4), + round(float(cos_rot), 4), + round(float(sin_rot), 4), + ] + res.append(row) return res # --- Parse Partners --- parsed_partners = [] for p in partners_obs: - if np.all(p == 0): + if _is_empty_obs_row(p): continue px, py = _rot(p[0], p[1]) @@ -1376,7 +1421,7 @@ def parse_roads(roads): # --- Parse Traffic Controls --- parsed_traffic_controls = [] for t in traffic_controls_obs: - if np.all(t == 0): + if _is_empty_obs_row(t): continue kind = _traffic_control_kind(t[5]) if kind is None: @@ -1410,6 +1455,7 @@ def parse_roads(roads): gps_data.append([round(float(gx), 3), round(float(gy), 3)]) return { + "target_type": target_type, "ego": ego_data, "partners": parsed_partners, "lanes": parse_roads(lane_obs), @@ -1417,6 +1463,11 @@ def parse_roads(roads): "traffic_controls": parsed_traffic_controls, "traj": traj_data, "gps": gps_data, + "view": { + "front": round(float(args["env"].get("road_obs_front_dist", max_position)) / max_position, 4), + "behind": round(float(args["env"].get("road_obs_behind_dist", max_position)) / max_position, 4), + "side": round(float(args["env"].get("road_obs_side_dist", max_position)) / max_position, 4), + }, } @@ -1510,10 +1561,6 @@ def round_floats(o): .panel { background: var(--panel-bg); padding: 18px; border-radius: 16px; box-shadow: 0 8px 30px var(--shadow); pointer-events: auto; backdrop-filter: blur(5px); } #hud-global { position: absolute; top: 20px; left: 20px; min-width: 220px; } - #hud-global h3 { cursor: pointer; user-select: none; } - #hud-global.collapsed { min-width: 0; padding-bottom: 14px; } - #hud-global.collapsed > :not(h3) { display: none; } - #hud-global.collapsed h3 { margin: 0; } #hud-telemetry { position: absolute; top: 80px; right: 20px; width: 340px; display: none; border-left: 6px solid var(--accent); background: rgba(15, 15, 15, 0.98); color: white; z-index: 20; } @@ -1575,8 +1622,8 @@ def round_floats(o):
SPACE: Play | ARROWS: Step | ESC: Free | CLICK: Follow | ENTER: Search
-