Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions baselines/ppo/config/ppo_base_puffer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ train:
checkpoint_path: "./runs"

# # # Rendering # # #
render: false # Determines whether to render the environment (note: will slow down training)
render: true # Determines whether to render the environment (note: will slow down training)
render_backend: "array" # Options: matplotlib, array
render_3d: true # Render simulator state in 3d or 2d
render_interval: 1 # Render every k iterations
render_k_scenarios: 10 # Number of scenarios to render
render_k_scenarios: 2 # Number of scenarios to render
render_format: "mp4" # Options: gif, mp4
render_fps: 15 # Frames per second
zoom_radius: 50
Expand Down
37 changes: 25 additions & 12 deletions gpudrive/env/env_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
from gpudrive.env.config import EnvConfig, RenderConfig

from gpudrive.env.env_torch import GPUDriveTorchEnv
from gpudrive.datatypes.observation import (
LocalEgoState,
)

from gpudrive.visualize.utils import img_from_fig
from gpudrive.env.dataset import SceneDataLoader
from gpudrive.visualize.utils import color_onehot_segmentation_map

from pufferlib.environment import PufferEnv
from gpudrive import GPU_DRIVE_DATA_DIR
Expand Down Expand Up @@ -66,6 +64,7 @@ def __init__(
render_format="mp4",
render_fps=15,
zoom_radius=50,
render_backend="array",
buf=None,
**kwargs,
):
Expand Down Expand Up @@ -97,6 +96,7 @@ def __init__(
self.render_format = render_format
self.render_fps = render_fps
self.zoom_radius = zoom_radius
self.render_backend = render_backend

# VBD
self.vbd_model_path = vbd_model_path
Expand Down Expand Up @@ -422,19 +422,32 @@ def render_env(self):
np.where(np.array(list(self.rendering_in_progress.values())))[0]
)
time_steps = list(self.episode_lengths[envs_to_render, 0])

if len(envs_to_render) > 0:
sim_state_figures = self.env.vis.plot_simulator_state(
env_indices=envs_to_render,
time_steps=time_steps,
zoom_radius=self.zoom_radius,
)
if self.render_backend == "matplotlib":
# Render bird's eye view using matplotlib
sim_state_figures = self.env.vis.plot_simulator_state(
env_indices=envs_to_render,
time_steps=time_steps,
zoom_radius=self.zoom_radius,
)

for idx, render_env_idx in enumerate(envs_to_render):
for idx, render_env_idx in enumerate(envs_to_render):
self.frames[render_env_idx].append(
img_from_fig(sim_state_figures[idx])
)
else:
# Render bird's eye view using raster scan algorithm
bev = self.env.sim.bev_observation_tensor().to_torch()
# Convert the BEV observation to a colored segmentation map
# If the tensor is one-hot encoded segmentation data
colored_bev = color_onehot_segmentation_map(bev, 'cpu')
agent_idx = 0
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is always the sdc_index, right? (I think so, if I am following json_serialization.hpp correctly.)
Maybe you could call this variable sdc_index just so it is clear we are logging that agent?

Copy link
Copy Markdown
Contributor Author

@daphne-cornelisse daphne-cornelisse Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the 0th index is always the sdc (we follow json_serialization.hpp too)

Note: The implementation doesn't work yet: the raster image I get is black and white. Also, do you think it would be possible to display the raster image in the case when we control multiple agents?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to permute the bev map so that the number of classes is the first dimension
bev.permute(2, 0, 1) or something like this.

For your second question, if I'm understanding correctly, I don't see an issue; you just might have to modify the colorization function to accept a leading batch dimension.


self.frames[render_env_idx].append(
img_from_fig(sim_state_figures[idx])
colored_bev[agent_idx, :]
)

def resample_scenario_batch(self):
"""Sample and set new batch of WOMD scenarios."""

Expand Down
30 changes: 30 additions & 0 deletions gpudrive/visualize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,36 @@

from gpudrive.visualize.color import ROAD_GRAPH_COLORS, ROAD_GRAPH_TYPE_NAMES

def color_onehot_segmentation_map(onehot_map, device):
"""
Args:
onehot_map: torch.Tensor of shape [num_classes, H, W], dtype=bool or float (one-hot)
device: torch device
Returns:
colored_image: [H, W, 3] uint8 image
"""
color_mapping = torch.tensor([
[0, 0, 0], # None
[125, 125, 125], # RoadEdge
[120, 120, 120], # RoadLine
[230, 230, 230], # RoadLane
[200, 200, 200], # CrossWalk
[217, 166, 33], # SpeedBump
[255, 0, 0], # StopSign
[0, 255, 255], # Vehicle
[0, 255, 0], # Pedestrian
[128, 0, 128], # Cyclist
[192, 192, 192], # Padding
], dtype=torch.uint8, device=device)

# Convert one-hot to class indices: [H, W]
class_map = onehot_map.argmax(dim=0)

# Map to color image: [H, W, 3]
colored_image = color_mapping[class_map]

return colored_image

def img_from_fig(fig: matplotlib.figure.Figure) -> np.ndarray:
"""Returns a [H, W, 3] uint8 np image from fig.canvas.tostring_rgb()."""
# Adjusted margins to better accommodate 3D plots
Expand Down