Skip to content

Commit b6d4cbc

Browse files
Copilot and sfluegel05 committed
Add create_multilabel_splits with iterstrat/sklearn for multilabel stratified splits
Co-authored-by: sfluegel05 <43573433+sfluegel05@users.noreply.github.com>
1 parent da7c433 commit b6d4cbc

File tree

4 files changed

+201
-2
lines changed

4 files changed

+201
-2
lines changed

chebi_utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf
22
from chebi_utils.obo_extractor import build_chebi_graph
33
from chebi_utils.sdf_extractor import extract_molecules
4-
from chebi_utils.splitter import create_splits
4+
from chebi_utils.splitter import create_multilabel_splits, create_splits
55

66
__all__ = [
77
"download_chebi_obo",
88
"download_chebi_sdf",
99
"build_chebi_graph",
1010
"extract_molecules",
1111
"create_splits",
12+
"create_multilabel_splits",
1213
]

chebi_utils/splitter.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,116 @@
66
import pandas as pd
77

88

9+
def create_multilabel_splits(
10+
df: pd.DataFrame,
11+
labels_col: str,
12+
train_ratio: float = 0.8,
13+
val_ratio: float = 0.1,
14+
test_ratio: float = 0.1,
15+
seed: int | None = 42,
16+
) -> dict[str, pd.DataFrame]:
17+
"""Create stratified train/validation/test splits for multilabel DataFrames.
18+
19+
Automatically detects whether the dataset is multilabel (each entry has
20+
more than one label) or single-label, and applies the appropriate
21+
stratification strategy:
22+
23+
- Multilabel: uses ``MultilabelStratifiedShuffleSplit`` from the
24+
``iterative-stratification`` package.
25+
- Single-label: uses ``StratifiedShuffleSplit`` from ``scikit-learn``.
26+
27+
Parameters
28+
----------
29+
df : pd.DataFrame
30+
Input data to split. Must contain a column ``labels_col`` whose
31+
values are sequences of labels (e.g. lists of strings or ints).
32+
labels_col : str
33+
Name of the column that contains the label sequences.
34+
train_ratio : float
35+
Fraction of data for training (default 0.8).
36+
val_ratio : float
37+
Fraction of data for validation (default 0.1).
38+
test_ratio : float
39+
Fraction of data for testing (default 0.1).
40+
seed : int or None
41+
Random seed for reproducibility.
42+
43+
Returns
44+
-------
45+
dict
46+
Dictionary with keys ``'train'``, ``'val'``, ``'test'``, each
47+
containing a DataFrame.
48+
49+
Raises
50+
------
51+
ValueError
52+
If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or
53+
``labels_col`` is not found in *df*.
54+
"""
55+
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
56+
raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
57+
if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]):
58+
raise ValueError("All ratios must be between 0 and 1")
59+
if labels_col not in df.columns:
60+
raise ValueError(f"Column '{labels_col}' not found in DataFrame")
61+
62+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
63+
from sklearn.model_selection import StratifiedShuffleSplit
64+
from sklearn.preprocessing import MultiLabelBinarizer
65+
66+
labels_list: list[list] = df[labels_col].tolist()
67+
is_multilabel = any(len(lbl) > 1 for lbl in labels_list)
68+
69+
df_reset = df.reset_index(drop=True)
70+
71+
if is_multilabel:
72+
mlb = MultiLabelBinarizer()
73+
labels_matrix = mlb.fit_transform(labels_list)
74+
else:
75+
labels_matrix = [lbl[0] for lbl in labels_list]
76+
77+
# ── Step 1: carve out the test set ──────────────────────────────────────
78+
if is_multilabel:
79+
test_splitter = MultilabelStratifiedShuffleSplit(
80+
n_splits=1, test_size=test_ratio, random_state=seed
81+
)
82+
train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix))
83+
else:
84+
test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed)
85+
train_val_idx, test_idx = next(test_splitter.split(labels_matrix, labels_matrix))
86+
87+
df_test = df_reset.iloc[test_idx]
88+
df_trainval = df_reset.iloc[train_val_idx]
89+
90+
# ── Step 2: split train/val from the remaining data ─────────────────────
91+
labels_trainval = (
92+
labels_matrix[train_val_idx]
93+
if is_multilabel
94+
else [labels_matrix[i] for i in train_val_idx]
95+
)
96+
val_ratio_adjusted = val_ratio / (1.0 - test_ratio)
97+
98+
if is_multilabel:
99+
val_splitter = MultilabelStratifiedShuffleSplit(
100+
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
101+
)
102+
train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval))
103+
else:
104+
val_splitter = StratifiedShuffleSplit(
105+
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
106+
)
107+
train_idx_inner, val_idx_inner = next(val_splitter.split(labels_trainval, labels_trainval))
108+
109+
df_train = df_trainval.iloc[train_idx_inner]
110+
df_val = df_trainval.iloc[val_idx_inner]
111+
112+
return {
113+
"train": df_train.reset_index(drop=True),
114+
"val": df_val.reset_index(drop=True),
115+
"test": df_test.reset_index(drop=True),
116+
}
117+
118+
9119
def create_splits(
10120
df: pd.DataFrame,
11121
train_ratio: float = 0.8,

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ license = { file = "LICENSE" }
1111
requires-python = ">=3.10"
1212
dependencies = [
1313
"fastobo>=0.14",
14+
"iterative-stratification>=0.1.9",
1415
"networkx>=3.0",
1516
"numpy>=1.24",
1617
"pandas>=2.0",
1718
"rdkit>=2022.09",
19+
"scikit-learn>=1.0",
1820
"chembl_structure_pipeline>=1.2.4",
1921
]
2022

tests/test_splitter.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
import pytest
77

8-
from chebi_utils.splitter import create_splits
8+
from chebi_utils.splitter import create_multilabel_splits, create_splits
99

1010

1111
@pytest.fixture
@@ -103,3 +103,89 @@ def test_stratified_reproducible(self, sample_df):
103103
splits1 = create_splits(sample_df, stratify_col="category", seed=42)
104104
splits2 = create_splits(sample_df, stratify_col="category", seed=42)
105105
pd.testing.assert_frame_equal(splits1["train"], splits2["train"])
106+
107+
108+
@pytest.fixture
def multilabel_df():
    """Build a 200-row frame whose 'labels' column cycles through six label sets.

    Three of the six patterns carry a single label and three carry two, so
    the frame as a whole is multilabel.
    """
    n_rows = 200
    label_cycle = [["A"], ["B"], ["C"], ["A", "B"], ["A", "C"], ["B", "C"]]
    # Repeat the cycle often enough to cover n_rows, then trim the overshoot.
    repeats = -(-n_rows // len(label_cycle))
    row_labels = (label_cycle * repeats)[:n_rows]
    row_ids = [f"CHEBI:{row}" for row in range(n_rows)]
    return pd.DataFrame({"id": row_ids, "labels": row_labels})
119+
120+
121+
@pytest.fixture
def singlelabel_df():
    """Build a 200-row frame whose 'labels' column alternates ["A"] and ["B"]."""
    n_rows = 200
    classes = ("A", "B")
    # Even rows get "A", odd rows get "B"; every row is a one-element list.
    row_labels = [[classes[row % 2]] for row in range(n_rows)]
    row_ids = [f"CHEBI:{row}" for row in range(n_rows)]
    return pd.DataFrame({"id": row_ids, "labels": row_labels})
130+
131+
132+
class TestCreateMultilabelSplits:
    """Behavioural tests for ``create_multilabel_splits``."""

    def test_returns_three_splits(self, multilabel_df):
        splits = create_multilabel_splits(multilabel_df, labels_col="labels")
        assert set(splits.keys()) == {"train", "val", "test"}

    def test_sizes_sum_to_total(self, multilabel_df):
        splits = create_multilabel_splits(multilabel_df, labels_col="labels")
        assert sum(len(v) for v in splits.values()) == len(multilabel_df)

    def test_no_overlap(self, multilabel_df):
        splits = create_multilabel_splits(multilabel_df, labels_col="labels")
        train_ids = set(splits["train"]["id"])
        val_ids = set(splits["val"]["id"])
        test_ids = set(splits["test"]["id"])
        assert train_ids.isdisjoint(val_ids)
        assert train_ids.isdisjoint(test_ids)
        assert val_ids.isdisjoint(test_ids)

    def test_all_rows_covered(self, multilabel_df):
        splits = create_multilabel_splits(multilabel_df, labels_col="labels")
        all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"])
        assert all_ids == set(multilabel_df["id"])

    def test_reproducible_with_same_seed(self, multilabel_df):
        splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7)
        splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7)
        pd.testing.assert_frame_equal(splits1["train"], splits2["train"])

    def test_different_seeds_give_different_splits(self, multilabel_df):
        splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=1)
        splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=2)
        assert not splits1["train"]["id"].equals(splits2["train"]["id"])

    def test_approximate_split_sizes(self, multilabel_df):
        splits = create_multilabel_splits(
            multilabel_df, labels_col="labels", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1
        )
        n = len(multilabel_df)
        # Iterative stratification balances labels, not row counts, so allow slack.
        assert abs(len(splits["test"]) - int(n * 0.1)) <= 2
        assert abs(len(splits["val"]) - int(n * 0.1)) <= 2

    def test_invalid_ratios_raise_error(self, multilabel_df):
        # ``match`` is a regex; escape the dot so it only matches the literal "1.0".
        with pytest.raises(ValueError, match=r"must equal 1\.0"):
            create_multilabel_splits(
                multilabel_df, labels_col="labels", train_ratio=0.5, val_ratio=0.3, test_ratio=0.3
            )

    def test_out_of_range_ratios_raise_error(self, multilabel_df):
        # These ratios sum to exactly 1.0, so only the range check can reject them.
        with pytest.raises(ValueError, match="between 0 and 1"):
            create_multilabel_splits(
                multilabel_df,
                labels_col="labels",
                train_ratio=1.5,
                val_ratio=-0.25,
                test_ratio=-0.25,
            )

    def test_missing_labels_col_raises_error(self, multilabel_df):
        with pytest.raises(ValueError, match="not found in DataFrame"):
            create_multilabel_splits(multilabel_df, labels_col="nonexistent")

    def test_singlelabel_path(self, singlelabel_df):
        """Single-label lists should use StratifiedShuffleSplit without error."""
        splits = create_multilabel_splits(singlelabel_df, labels_col="labels")
        assert sum(len(v) for v in splits.values()) == len(singlelabel_df)
        train_ids = set(splits["train"]["id"])
        val_ids = set(splits["val"]["id"])
        test_ids = set(splits["test"]["id"])
        assert train_ids.isdisjoint(val_ids)
        assert train_ids.isdisjoint(test_ids)
        assert val_ids.isdisjoint(test_ids)

0 commit comments

Comments
 (0)