@@ -96,9 +96,9 @@ def __init__(
9696 self .prediction_kind = prediction_kind
9797 self .data_limit = data_limit
9898 self .label_filter = label_filter
99- assert (balance_after_filter is not None ) or (
100- self . label_filter is None
101- ), "Filter balancing requires a filter"
99+ assert (balance_after_filter is not None ) or (self . label_filter is None ), (
100+ "Filter balancing requires a filter"
101+ )
102102 self .balance_after_filter = balance_after_filter
103103 self .num_workers = num_workers
104104 self .persistent_workers : bool = bool (persistent_workers )
@@ -108,13 +108,13 @@ def __init__(
108108 self .use_inner_cross_validation = (
109109 inner_k_folds > 1
110110 ) # only use cv if there are at least 2 folds
111- assert (
112- fold_index is None or self . use_inner_cross_validation is not None
113- ), "fold_index can only be set if cross validation is used"
111+ assert fold_index is None or self . use_inner_cross_validation is not None , (
112+ " fold_index can only be set if cross validation is used"
113+ )
114114 if fold_index is not None and self .inner_k_folds is not None :
115- assert (
116- fold_index < self . inner_k_folds
117- ), "fold_index can't be larger than the total number of folds"
115+ assert fold_index < self . inner_k_folds , (
116+ " fold_index can't be larger than the total number of folds"
117+ )
118118 self .fold_index = fold_index
119119 self ._base_dir = base_dir
120120 self .n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137
138138 @property
139139 def feature_vector_size (self ):
140- assert (
141- self . _feature_vector_size is not None
142- ), "size of feature vector must be set"
140+ assert self . _feature_vector_size is not None , (
141+ "size of feature vector must be set"
142+ )
143143 return self ._feature_vector_size
144144
145145 @property
@@ -619,6 +619,19 @@ def raw_file_names_dict(self) -> dict:
619619 """
620620 raise NotImplementedError
621621
622+ @property
623+ def classes_txt_file_path (self ) -> str :
624+ """
625+ Returns the filename for the classes text file.
626+
627+ Returns:
628+ str: The filename for the classes text file.
629+ """
630+ # This property also used in following places:
631+ # - results/prediction.py: to load class names for csv columns names
632+ # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
633+ return os .path .join (self .processed_dir_main , "classes.txt" )
634+
622635
623636class MergedDataset (XYBaseDataModule ):
624637 MERGED = []
@@ -1373,14 +1386,3 @@ def processed_file_names_dict(self) -> dict:
13731386 if self .n_token_limit is not None :
13741387 return {"data" : f"data_maxlen{ self .n_token_limit } .pt" }
13751388 return {"data" : "data.pt" }
1376-
1377- @property
1378- def classes_txt_file_path (self ) -> str :
1379- """
1380- Returns the filename for the classes text file.
1381-
1382- Returns:
1383- str: The filename for the classes text file.
1384- """
1385- # This property also used in custom trainer `chebai/trainer/CustomTrainer.py`
1386- return os .path .join (self .processed_dir_main , "classes.txt" )
0 commit comments