33
44import typer
55
6- from nuplan .database .nuplan_db .nuplandb import NuPlanDB
7- from nuplan .database .nuplan_db .scenario_tag import ScenarioTag
6+ from nuplan .database .nuplan_db .db_cli_queries import (
7+ get_db_description ,
8+ get_db_duration_in_us ,
9+ get_db_log_duration ,
10+ get_db_log_vehicles ,
11+ get_db_scenario_info ,
12+ )
13+ from nuplan .planning .scenario_builder .nuplan_db .nuplan_scenario_utils import download_file_if_necessary
814
915cli = typer .Typer ()
1016
1117NUPLAN_DATA_ROOT = os .getenv ('NUPLAN_DATA_ROOT' , "/data/sets/nuplan/" )
1218NUPLAN_DB_VERSION = f'{ NUPLAN_DATA_ROOT } /nuplan-v1.0/mini/2021.07.16.20.45.29_veh-35_01095_01486.db'
1319
1420
21+ def _ensure_file_downloaded (data_root : str , potentially_remote_path : str ) -> str :
22+ """
23+ Attempts to download the DB file from a remote URL if it does not exist locally.
24+ If the download fails, an error will be raised.
25+ :param data_root: The location to download the file, if necessary.
26+ :param potentially_remote_path: The path to the file.
27+ :return: The resulting file path. Will be one of a few options:
28+ * If potentially_remote_path points to a local file, will return potentially_remote_path
29+ * If potentially_remote_file points to a remote file, it does not exist currently, and the file can be successfully downloaded, it will return the path of the downloaded file.
30+ * In all other cases, an error will be raised.
31+ """
32+ output_file_path : str = download_file_if_necessary (data_root , potentially_remote_path )
33+
34+ if not os .path .exists (output_file_path ):
35+ raise ValueError (f"{ potentially_remote_path } could not be downloaded." )
36+
37+ return output_file_path
38+
39+
1540@cli .command ()
1641def info (
1742 db_version : str = typer .Argument (NUPLAN_DB_VERSION , help = "The database version." ),
@@ -20,11 +45,24 @@ def info(
2045 """
2146 Print out detailed information about the selected database.
2247 """
23- # Construct database
24- db = NuPlanDB (load_path = db_version , data_root = data_root )
25- # Use the default __str__
26- typer .echo ("DB info" )
27- typer .echo (db )
48+ db_version = _ensure_file_downloaded (data_root , db_version )
49+ db_description = get_db_description (db_version )
50+
51+ for table_name , table_description in db_description .tables .items ():
52+ typer .echo (f"Table { table_name } : { table_description .row_count } rows" )
53+
54+ for column_name , column_description in table_description .columns .items ():
55+ typer .echo (
56+ "" .join (
57+ [
58+ f"\t column { column_name } : { column_description .data_type } " ,
59+ "NULL " if column_description .nullable else "NOT NULL " ,
60+ "PRIMARY KEY " if column_description .is_primary_key else "" ,
61+ ]
62+ )
63+ )
64+
65+ typer .echo ()
2866
2967
3068@cli .command ()
@@ -35,16 +73,11 @@ def duration(
3573 """
3674 Print out the duration of the selected db.
3775 """
38- # Construct database
39- db = NuPlanDB (load_path = db_version , data_root = data_root )
40-
41- # Approximate the duration of db by dividing the number of lidar_pc and the frequency of the DB
42- assumed_db_frequency = 20
43- db_duration_s = len (db .lidar_pc ) / assumed_db_frequency
76+ db_version = _ensure_file_downloaded (data_root , db_version )
77+ db_duration_us = get_db_duration_in_us (db_version )
78+ db_duration_s = float (db_duration_us ) / 1e6
4479 db_duration_str = time .strftime ("%H:%M:%S" , time .gmtime (db_duration_s ))
45- typer .echo (
46- f"DB approximate duration (assuming db frequency { assumed_db_frequency } Hz) is { db_duration_str } [HH:MM:SS]"
47- )
80+ typer .echo (f"DB duration is { db_duration_str } [HH:MM:SS]" )
4881
4982
5083@cli .command ()
@@ -55,20 +88,15 @@ def log_duration(
5588 """
5689 Print out the duration of every log in the selected db.
5790 """
58- # Construct database
59- db = NuPlanDB (load_path = db_version , data_root = data_root )
60-
61- # Approximate the duration of db by dividing the number of lidar_pc and the frequency of the DB
62- assumed_db_frequency = 20
91+ db_version = _ensure_file_downloaded (data_root , db_version )
92+ num_logs = 0
93+ for log_file_name , log_file_duration_us in get_db_log_duration (db_version ):
94+ log_file_duration_s = float (log_file_duration_us ) / 1e6
95+ log_file_duration_str = time .strftime ("%H:%M:%S" , time .gmtime (log_file_duration_s ))
96+ typer .echo (f"The duration of log { log_file_name } is { log_file_duration_str } [HH:MM:SS]" )
97+ num_logs += 1
6398
64- # Print out for every log the approximate durations
65- typer .echo (f"The DB: { db .name } contains { len (db .log )} logs" )
66-
67- for log in db .log :
68- lidar_pcs = [lidar for scene in log .scenes for lidar in scene .lidar_pcs ]
69- db_duration_s = len (lidar_pcs ) / assumed_db_frequency
70- db_duration_str = time .strftime ("%H:%M:%S" , time .gmtime (db_duration_s ))
71- typer .echo (f"\t The approximate duration of log { log .logfile } is { db_duration_str } [HH:MM:SS]" )
99+ typer .echo (f"There are { num_logs } total logs." )
72100
73101
74102@cli .command ()
@@ -79,14 +107,9 @@ def log_vehicle(
79107 """
80108 Print out vehicle information from every log in the selected database.
81109 """
82- # Construct database
83- db = NuPlanDB (load_path = db_version , data_root = data_root )
84-
85- # Print out for every log the used vehicle
86- typer .echo ("The used vehicles for every log follow:" )
87-
88- for log in db .log :
89- typer .echo (f"\t For the log { log .logfile } vehicle { log .vehicle_name } of type { log .vehicle_type } was used" )
110+ db_version = _ensure_file_downloaded (data_root , db_version )
111+ for log_file , vehicle_name in get_db_log_vehicles (db_version ):
112+ typer .echo (f"For the log { log_file } , vehicle { vehicle_name } was used." )
90113
91114
92115@cli .command ()
@@ -97,21 +120,13 @@ def scenarios(
97120 """
98121 Print out the available scenarios in the selected db.
99122 """
100- # Construct database
101- db = NuPlanDB (load_path = db_version , data_root = data_root )
102-
103- # Read all available tags:
104- available_types = [tag [0 ] for tag in db .session .query (ScenarioTag .type ).distinct ().all ()]
105-
106- # Tag table
107- tag_table = db .scenario_tag
108-
109- # Print out the available scenarios
110- typer .echo (f"The available scenario tags from db: { db_version } follow, in total { len (available_types )} scenarios" )
123+ db_version = _ensure_file_downloaded (data_root , db_version )
124+ total_count = 0
125+ for tag , num_scenarios in get_db_scenario_info (db_version ):
126+ typer .echo (f"{ tag } : { num_scenarios } scenarios." )
127+ total_count += num_scenarios
111128
112- for tag in available_types :
113- tags = tag_table .select_many (type = tag )
114- typer .echo (f"\t - { tag } has { len (tags )} scenarios" )
129+ typer .echo (f"TOTAL: { total_count } scenarios." )
115130
116131
117132if __name__ == '__main__' :
0 commit comments