@@ -1173,9 +1173,7 @@ def _retrieve_splits_from_csv(self) -> None:
11731173 splits_df = pd .read_csv (self .splits_file_path )
11741174
11751175 filename = self .processed_file_names_dict ["data" ]
1176- data = self .load_processed_data_from_file (
1177- os .path .join (self .processed_dir , filename )
1178- )
1176+ data = self .load_processed_data_from_file (filename )
11791177 df_data = pd .DataFrame (data )
11801178
11811179 if self .apply_id_filter :
@@ -1254,8 +1252,23 @@ def load_processed_data(
12541252 # If filename is provided
12551253 return self .load_processed_data_from_file (filename )
12561254
1257- def load_processed_data_from_file (self , filename ):
1258- return torch .load (os .path .join (filename ), weights_only = False )
1255+ def load_processed_data_from_file (self , filename : str ) -> list [dict [str , Any ]]:
1256+ """Load processed data from a file.
1257+
1258+ The full path is not required; only the filename is needed, as it will be joined with the processed directory.
1259+
1260+ Args:
1261+ filename (str): The name of the file to load the processed data from.
1262+
1263+ Returns:
1264+ List[Dict[str, Any]]: The loaded processed data.
1265+
1266+ Example:
1267+ data = self.load_processed_data_from_file('data.pt')
1268+ """
1269+ return torch .load (
1270+ os .path .join (self .processed_dir , filename ), weights_only = False
1271+ )
12591272
12601273 # ------------------------------ Phase: Raw Properties -----------------------------------
12611274 @property
0 commit comments