@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (
-            self.label_filter is None
-        ), "Filter balancing requires a filter"
+        assert (balance_after_filter is not None) or (self.label_filter is None), (
+            "Filter balancing requires a filter"
+        )
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert (
-            fold_index is None or self.use_inner_cross_validation is not None
-        ), "fold_index can only be set if cross validation is used"
+        assert fold_index is None or self.use_inner_cross_validation is not None, (
+            "fold_index can only be set if cross validation is used"
+        )
         if fold_index is not None and self.inner_k_folds is not None:
-            assert (
-                fold_index < self.inner_k_folds
-            ), "fold_index can't be larger than the total number of folds"
+            assert fold_index < self.inner_k_folds, (
+                "fold_index can't be larger than the total number of folds"
+            )
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert (
-            self._feature_vector_size is not None
-        ), "size of feature vector must be set"
+        assert self._feature_vector_size is not None, (
+            "size of feature vector must be set"
+        )
         return self._feature_vector_size
 
     @property
@@ -1242,9 +1242,7 @@ def _retrieve_splits_from_csv(self) -> None:
         splits_df = pd.read_csv(self.splits_file_path)
 
         filename = self.processed_file_names_dict["data"]
-        data = self.load_processed_data_from_file(
-            os.path.join(self.processed_dir, filename)
-        )
+        data = self.load_processed_data_from_file(filename)
         df_data = pd.DataFrame(data)
 
         if self.apply_id_filter:
@@ -1325,7 +1323,9 @@ def load_processed_data(
         return self.load_processed_data_from_file(filename)
 
     def load_processed_data_from_file(self, filename):
-        return torch.load(os.path.join(filename), weights_only=False)
+        return torch.load(
+            os.path.join(self.processed_dir, filename), weights_only=False
+        )
 
     # ------------------------------ Phase: Raw Properties -----------------------------------
     @property
0 commit comments