Skip to content

Commit 28de5b4

Browse files
committed
Add agent getting to nuscenes
1 parent e75d548 commit 28de5b4

2 files changed

Lines changed: 22 additions & 5 deletions

File tree

avapi/_dataset.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def check_frame(self, frame):
134134
def get_agents(self, frame: int) -> "DataContainer":
135135
return self._load_agents(frame)
136136

137-
def get_agent(self, frame: int, agent: int):
137+
def get_agent(self, frame: int, agent: int) -> List[VehicleState]:
138138
agents = self.get_agents(frame)
139139
return [ag for ag in agents if ag.ID == agent][0]
140140

@@ -158,6 +158,9 @@ def get_sensor_name(self, sensor, agent=None) -> str:
158158
return sensor
159159
else:
160160
return self.sensors[sensor]
161+
162+
def get_sensor_names_by_type(self, sensor_type: str, agent=None) -> List[str]:
163+
return self._load_sensor_names_by_type(sensor_type=sensor_type, agent=agent)
161164

162165
def get_frames(self, sensor, agent=None) -> List[int]:
163166
sensor = self.get_sensor_name(sensor, agent=agent)
@@ -344,6 +347,9 @@ def save_objects(self, frame, objects, folder, file=None):
344347
os.makedirs(folder)
345348
self._save_objects(frame, objects, folder, file=file)
346349

350+
def _load_sensor_names_by_type(self, sensor_type, agent):
351+
raise NotImplementedError
352+
347353
def _load_agents(self, frame):
348354
raise NotImplementedError
349355

@@ -371,8 +377,8 @@ def _load_lidar(self, frame, sensor, agent=None):
371377
def _load_objects(self, frame, sensor, agent=None):
372378
raise NotImplementedError
373379

374-
def _load_sensor_data_filepath(self, frame, sensor: str):
375-
return self._get_sensor_file_name(frame, sensor)
380+
def _load_sensor_data_filepath(self, frame, sensor: str, agent=None):
381+
return self._get_sensor_file_name(frame, sensor, agent=agent)
376382

377383
def _load_objects_from_file(
378384
self,
@@ -752,7 +758,7 @@ def _load_image(self, frame, sensor=None, **kwargs):
752758
img_fname = self._get_sensor_file_name(frame, sensor)
753759
return imread(img_fname)
754760

755-
def _load_ego(self, frame, **kwargs):
761+
def _load_ego(self, frame, **kwargs) -> VehicleState:
756762
ref = GlobalOrigin3D
757763
if self.vehicle_pose is not None:
758764
try:

avapi/nuscenes/dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
from avstack.config import DATASETS
8+
from avstack.environment.objects import VehicleState
89
from avstack.geometry.transformations import matrix_cartesian_to_spherical
910
from scipy.interpolate import interp1d
1011

@@ -164,9 +165,19 @@ def make_sample_records(self):
164165
)
165166
self.t0 = self.sample_records[0]["timestamp"] / 1e6
166167

167-
def get_agents(self, frame: int) -> List:
168+
def get_agents(self, frame: int) -> List[VehicleState]:
168169
return [self._load_ego(frame=frame)]
169170

171+
def _load_sensor_names_by_type(self, sensor_type, **kwargs):
172+
if sensor_type.lower() == "camera":
173+
return [sID for sID in self.sensor_IDs if "CAM" in sID]
174+
elif sensor_type.lower() == "lidar":
175+
return [sID for sID in self.sensor_IDs if "LIDAR" in sID]
176+
elif sensor_type.lower() == "radar":
177+
return [sID for sID in self.sensor_IDs if "RADAR" in sID]
178+
else:
179+
raise NotImplementedError(sensor_type)
180+
170181
def _load_lidar(
171182
self,
172183
frame,

0 commit comments

Comments
 (0)