@@ -940,7 +940,7 @@ def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
940940 def _graph_to_raw_dataset (self , graph : "nx.DiGraph" ) -> pd .DataFrame :
941941 """
942942 Converts the graph to a raw dataset.
943- Uses the graph created by `_extract_class_hierarchy` method to extract the
943+ Uses the graph created by chebi_utils to extract the
944944 raw data in Dataframe format with additional columns corresponding to each multi-label class.
945945
946946 Args:
@@ -951,21 +951,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
951951 """
952952 pass
953953
954- @abstractmethod
955- def select_classes (self , g : "nx.DiGraph" , * args , ** kwargs ) -> List :
956- """
957- Selects classes from the dataset based on a specified criteria.
958-
959- Args:
960- g (nx.Graph): The graph representing the dataset.
961- *args: Additional positional arguments.
962- **kwargs: Additional keyword arguments.
963-
964- Returns:
965- List: A sorted list of node IDs that meet the specified criteria.
966- """
967- pass
968-
969954 def save_processed (self , data : pd .DataFrame , filename : str ) -> None :
970955 """
971956 Save the processed dataset to a pickle file.
@@ -1123,120 +1108,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
11231108 """
11241109 pass
11251110
1126- def get_test_split (
1127- self , df : pd .DataFrame , seed : Optional [int ] = None
1128- ) -> Tuple [pd .DataFrame , pd .DataFrame ]:
1129- """
1130- Split the input DataFrame into training and testing sets based on multilabel stratified sampling.
1131-
1132- This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
1133- in the training and testing sets is approximately the same. The split is based on the "labels" column
1134- in the DataFrame.
1135-
1136- Args:
1137- df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
1138- named "labels" with the multilabel data.
1139- seed (int, optional): The random seed to be used for reproducibility. Default is None.
1140-
1141- Returns:
1142- Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.
1143-
1144- Raises:
1145- ValueError: If the DataFrame does not contain a column named "labels".
1146- """
1147- from iterstrat .ml_stratifiers import MultilabelStratifiedShuffleSplit
1148- from sklearn .model_selection import StratifiedShuffleSplit
1149-
1150- print ("Get test data split" )
1151-
1152- labels_list = df ["labels" ].tolist ()
1153-
1154- if len (labels_list [0 ]) > 1 :
1155- splitter = MultilabelStratifiedShuffleSplit (
1156- n_splits = 1 , test_size = self .test_split , random_state = seed
1157- )
1158- else :
1159- splitter = StratifiedShuffleSplit (
1160- n_splits = 1 , test_size = self .test_split , random_state = seed
1161- )
1162-
1163- train_indices , test_indices = next (splitter .split (labels_list , labels_list ))
1164-
1165- df_train = df .iloc [train_indices ]
1166- df_test = df .iloc [test_indices ]
1167- return df_train , df_test
1168-
1169- def get_train_val_splits_given_test (
1170- self , df : pd .DataFrame , test_df : pd .DataFrame , seed : int = None
1171- ) -> Union [Dict [str , pd .DataFrame ], Tuple [pd .DataFrame , pd .DataFrame ]]:
1172- """
1173- Split the dataset into train and validation sets, given a test set.
1174- Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap
1175-
1176- Args:
1177- df (pd.DataFrame): The original dataset.
1178- test_df (pd.DataFrame): The test dataset.
1179- seed (int, optional): The random seed to be used for reproducibility. Default is None.
1180-
1181- Returns:
1182- Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and
1183- validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train
1184- and validation DataFrames. The keys are the names of the train and validation sets, and the values
1185- are the corresponding DataFrames.
1186- """
1187- from iterstrat .ml_stratifiers import (
1188- MultilabelStratifiedKFold ,
1189- MultilabelStratifiedShuffleSplit ,
1190- )
1191- from sklearn .model_selection import StratifiedShuffleSplit
1192-
1193- print ("Split dataset into train / val with given test set" )
1194-
1195- test_ids = test_df ["ident" ].tolist ()
1196- df_trainval = df [~ df ["ident" ].isin (test_ids )]
1197- labels_list_trainval = df_trainval ["labels" ].tolist ()
1198-
1199- if self .use_inner_cross_validation :
1200- folds = {}
1201- kfold = MultilabelStratifiedKFold (
1202- n_splits = self .inner_k_folds , random_state = seed
1203- )
1204- for fold , (train_ids , val_ids ) in enumerate (
1205- kfold .split (
1206- labels_list_trainval ,
1207- labels_list_trainval ,
1208- )
1209- ):
1210- df_validation = df_trainval .iloc [val_ids ]
1211- df_train = df_trainval .iloc [train_ids ]
1212- folds [self .raw_file_names_dict [f"fold_{ fold } _train" ]] = df_train
1213- folds [self .raw_file_names_dict [f"fold_{ fold } _validation" ]] = (
1214- df_validation
1215- )
1216-
1217- return folds
1218-
1219- if len (labels_list_trainval [0 ]) > 1 :
1220- splitter = MultilabelStratifiedShuffleSplit (
1221- n_splits = 1 ,
1222- test_size = self .validation_split / (1 - self .test_split ),
1223- random_state = seed ,
1224- )
1225- else :
1226- splitter = StratifiedShuffleSplit (
1227- n_splits = 1 ,
1228- test_size = self .validation_split / (1 - self .test_split ),
1229- random_state = seed ,
1230- )
1231-
1232- train_indices , validation_indices = next (
1233- splitter .split (labels_list_trainval , labels_list_trainval )
1234- )
1235-
1236- df_validation = df_trainval .iloc [validation_indices ]
1237- df_train = df_trainval .iloc [train_indices ]
1238- return df_train , df_validation
1239-
12401111 def _retrieve_splits_from_csv (self ) -> None :
12411112 """
12421113 Retrieve previously saved data splits from splits.csv file or from provided file path.
0 commit comments