Skip to content

Commit 242ade0

Browse files
committed
use chebi utils library for dataset preparation
1 parent ab872a2 commit 242ade0

File tree

3 files changed

+53
-481
lines changed

3 files changed

+53
-481
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 1 addition & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
940940
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
941941
"""
942942
Converts the graph to a raw dataset.
943-
Uses the graph created by `_extract_class_hierarchy` method to extract the
943+
Uses the graph created by chebi_utils to extract the
944944
raw data in Dataframe format with additional columns corresponding to each multi-label class.
945945
946946
Args:
@@ -951,21 +951,6 @@ def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
951951
"""
952952
pass
953953

954-
@abstractmethod
955-
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
956-
"""
957-
Selects classes from the dataset based on a specified criteria.
958-
959-
Args:
960-
g (nx.Graph): The graph representing the dataset.
961-
*args: Additional positional arguments.
962-
**kwargs: Additional keyword arguments.
963-
964-
Returns:
965-
List: A sorted list of node IDs that meet the specified criteria.
966-
"""
967-
pass
968-
969954
def save_processed(self, data: pd.DataFrame, filename: str) -> None:
970955
"""
971956
Save the processed dataset to a pickle file.
@@ -1123,120 +1108,6 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
11231108
"""
11241109
pass
11251110

1126-
def get_test_split(
1127-
self, df: pd.DataFrame, seed: Optional[int] = None
1128-
) -> Tuple[pd.DataFrame, pd.DataFrame]:
1129-
"""
1130-
Split the input DataFrame into training and testing sets based on multilabel stratified sampling.
1131-
1132-
This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
1133-
in the training and testing sets is approximately the same. The split is based on the "labels" column
1134-
in the DataFrame.
1135-
1136-
Args:
1137-
df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column
1138-
named "labels" with the multilabel data.
1139-
seed (int, optional): The random seed to be used for reproducibility. Default is None.
1140-
1141-
Returns:
1142-
Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames.
1143-
1144-
Raises:
1145-
ValueError: If the DataFrame does not contain a column named "labels".
1146-
"""
1147-
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1148-
from sklearn.model_selection import StratifiedShuffleSplit
1149-
1150-
print("Get test data split")
1151-
1152-
labels_list = df["labels"].tolist()
1153-
1154-
if len(labels_list[0]) > 1:
1155-
splitter = MultilabelStratifiedShuffleSplit(
1156-
n_splits=1, test_size=self.test_split, random_state=seed
1157-
)
1158-
else:
1159-
splitter = StratifiedShuffleSplit(
1160-
n_splits=1, test_size=self.test_split, random_state=seed
1161-
)
1162-
1163-
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
1164-
1165-
df_train = df.iloc[train_indices]
1166-
df_test = df.iloc[test_indices]
1167-
return df_train, df_test
1168-
1169-
def get_train_val_splits_given_test(
1170-
self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None
1171-
) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]:
1172-
"""
1173-
Split the dataset into train and validation sets, given a test set.
1174-
Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap
1175-
1176-
Args:
1177-
df (pd.DataFrame): The original dataset.
1178-
test_df (pd.DataFrame): The test dataset.
1179-
seed (int, optional): The random seed to be used for reproducibility. Default is None.
1180-
1181-
Returns:
1182-
Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and
1183-
validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train
1184-
and validation DataFrames. The keys are the names of the train and validation sets, and the values
1185-
are the corresponding DataFrames.
1186-
"""
1187-
from iterstrat.ml_stratifiers import (
1188-
MultilabelStratifiedKFold,
1189-
MultilabelStratifiedShuffleSplit,
1190-
)
1191-
from sklearn.model_selection import StratifiedShuffleSplit
1192-
1193-
print("Split dataset into train / val with given test set")
1194-
1195-
test_ids = test_df["ident"].tolist()
1196-
df_trainval = df[~df["ident"].isin(test_ids)]
1197-
labels_list_trainval = df_trainval["labels"].tolist()
1198-
1199-
if self.use_inner_cross_validation:
1200-
folds = {}
1201-
kfold = MultilabelStratifiedKFold(
1202-
n_splits=self.inner_k_folds, random_state=seed
1203-
)
1204-
for fold, (train_ids, val_ids) in enumerate(
1205-
kfold.split(
1206-
labels_list_trainval,
1207-
labels_list_trainval,
1208-
)
1209-
):
1210-
df_validation = df_trainval.iloc[val_ids]
1211-
df_train = df_trainval.iloc[train_ids]
1212-
folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train
1213-
folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = (
1214-
df_validation
1215-
)
1216-
1217-
return folds
1218-
1219-
if len(labels_list_trainval[0]) > 1:
1220-
splitter = MultilabelStratifiedShuffleSplit(
1221-
n_splits=1,
1222-
test_size=self.validation_split / (1 - self.test_split),
1223-
random_state=seed,
1224-
)
1225-
else:
1226-
splitter = StratifiedShuffleSplit(
1227-
n_splits=1,
1228-
test_size=self.validation_split / (1 - self.test_split),
1229-
random_state=seed,
1230-
)
1231-
1232-
train_indices, validation_indices = next(
1233-
splitter.split(labels_list_trainval, labels_list_trainval)
1234-
)
1235-
1236-
df_validation = df_trainval.iloc[validation_indices]
1237-
df_train = df_trainval.iloc[train_indices]
1238-
return df_train, df_validation
1239-
12401111
def _retrieve_splits_from_csv(self) -> None:
12411112
"""
12421113
Retrieve previously saved data splits from splits.csv file or from provided file path.

0 commit comments

Comments
 (0)