Skip to content

Commit 7a4e4d6

Browse files
authored
Merge pull request #4 from ChEB-AI/copilot/add-stratified-splits-dataset
Add multilabel stratified train/val/test splits
2 parents 1da984f + 2381201 commit 7a4e4d6

File tree

8 files changed

+188
-130
lines changed

8 files changed

+188
-130
lines changed

chebi_utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf
33
from chebi_utils.obo_extractor import build_chebi_graph
44
from chebi_utils.sdf_extractor import extract_molecules
5-
from chebi_utils.splitter import create_splits
5+
from chebi_utils.splitter import create_multilabel_splits
66

77
__all__ = [
88
"build_labeled_dataset",
99
"download_chebi_obo",
1010
"download_chebi_sdf",
1111
"build_chebi_graph",
1212
"extract_molecules",
13-
"create_splits",
13+
"create_multilabel_splits",
1414
]

chebi_utils/obo_extractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph:
116116
return graph
117117

118118

119-
def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.Graph:
120-
"""Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a). Also removes nodes that are not connected by any is_a edges to other nodes."""
119+
def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph:
120+
"""Subgraph of ChEBI including only edges corresponding to hierarchical relations (is_a).
121+
Also removes nodes that are not connected by any is_a edges to other nodes."""
121122
return chebi_graph.edge_subgraph(
122123
(u, v) for u, v, d in chebi_graph.edges(data=True) if d.get("relation") == "is_a"
123124
)

chebi_utils/splitter.py

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,43 @@
22

33
from __future__ import annotations
44

5-
import numpy as np
65
import pandas as pd
76

87

9-
def create_splits(
8+
def create_multilabel_splits(
109
df: pd.DataFrame,
10+
label_start_col: int = 2,
1111
train_ratio: float = 0.8,
1212
val_ratio: float = 0.1,
1313
test_ratio: float = 0.1,
14-
stratify_col: str | None = None,
15-
seed: int = 42,
14+
seed: int | None = 42,
1615
) -> dict[str, pd.DataFrame]:
17-
"""Create stratified train/validation/test splits of a DataFrame.
16+
"""Create stratified train/validation/test splits for multilabel DataFrames.
17+
18+
Columns from index *label_start_col* onwards are treated as binary label
19+
columns (one boolean column per label). The stratification strategy is
20+
chosen automatically based on the number of label columns:
21+
22+
- More than one label column: ``MultilabelStratifiedShuffleSplit`` from
23+
the ``iterative-stratification`` package.
24+
- Single label column: ``StratifiedShuffleSplit`` from ``scikit-learn``.
1825
1926
Parameters
2027
----------
2128
df : pd.DataFrame
22-
Input data to split.
29+
Input data. Columns ``0`` to ``label_start_col - 1`` are treated as
30+
feature/metadata columns; all remaining columns are boolean label
31+
columns. A typical ChEBI DataFrame has columns
32+
``["chebi_id", "mol", "label1", "label2", ...]``.
33+
label_start_col : int
34+
Index of the first label column (default 2).
2335
train_ratio : float
2436
Fraction of data for training (default 0.8).
2537
val_ratio : float
2638
Fraction of data for validation (default 0.1).
2739
test_ratio : float
2840
Fraction of data for testing (default 0.1).
29-
stratify_col : str or None
30-
Column name to use for stratification. If None, splits are random.
31-
seed : int
41+
seed : int or None
3242
Random seed for reproducibility.
3343
3444
Returns
@@ -40,44 +50,60 @@ def create_splits(
4050
Raises
4151
------
4252
ValueError
43-
If the ratios do not sum to 1 or any ratio is outside ``[0, 1]``.
53+
If the ratios do not sum to 1, any ratio is outside ``[0, 1]``, or
54+
*label_start_col* is out of range.
4455
"""
4556
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
4657
raise ValueError("train_ratio + val_ratio + test_ratio must equal 1.0")
4758
if any(r < 0 or r > 1 for r in [train_ratio, val_ratio, test_ratio]):
4859
raise ValueError("All ratios must be between 0 and 1")
60+
if label_start_col >= len(df.columns):
61+
raise ValueError(
62+
f"label_start_col={label_start_col} is out of range for a DataFrame "
63+
f"with {len(df.columns)} columns"
64+
)
4965

50-
rng = np.random.default_rng(seed)
66+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
67+
from sklearn.model_selection import StratifiedShuffleSplit
5168

52-
if stratify_col is not None:
53-
return _stratified_split(df, train_ratio, val_ratio, test_ratio, stratify_col, rng)
54-
return _random_split(df, train_ratio, val_ratio, test_ratio, rng)
69+
labels_matrix = df.iloc[:, label_start_col:].values
70+
is_multilabel = labels_matrix.shape[1] > 1
71+
# StratifiedShuffleSplit requires a 1-D label array
72+
y = labels_matrix if is_multilabel else labels_matrix[:, 0]
5573

74+
df_reset = df.reset_index(drop=True)
5675

57-
def _stratified_split(
58-
df: pd.DataFrame,
59-
train_ratio: float,
60-
val_ratio: float,
61-
test_ratio: float, # noqa: ARG001
62-
stratify_col: str,
63-
rng: np.random.Generator,
64-
) -> dict[str, pd.DataFrame]:
65-
train_indices: list[int] = []
66-
val_indices: list[int] = []
67-
test_indices: list[int] = []
76+
# ── Step 1: carve out the test set ──────────────────────────────────────
77+
if is_multilabel:
78+
test_splitter = MultilabelStratifiedShuffleSplit(
79+
n_splits=1, test_size=test_ratio, random_state=seed
80+
)
81+
else:
82+
test_splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=seed)
83+
train_val_idx, test_idx = next(test_splitter.split(y, y))
84+
85+
df_test = df_reset.iloc[test_idx]
86+
df_trainval = df_reset.iloc[train_val_idx]
87+
88+
# ── Step 2: split train/val from the remaining data ─────────────────────
89+
y_trainval = y[train_val_idx]
90+
val_ratio_adjusted = val_ratio / (1.0 - test_ratio)
6891

69-
for _, group in df.groupby(stratify_col, sort=False):
70-
group_indices = rng.permutation(np.array(group.index.tolist()))
71-
n = len(group_indices)
72-
n_train = max(1, int(n * train_ratio))
73-
n_val = max(0, int(n * val_ratio))
92+
if is_multilabel:
93+
val_splitter = MultilabelStratifiedShuffleSplit(
94+
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
95+
)
96+
else:
97+
val_splitter = StratifiedShuffleSplit(
98+
n_splits=1, test_size=val_ratio_adjusted, random_state=seed
99+
)
100+
train_idx_inner, val_idx_inner = next(val_splitter.split(y_trainval, y_trainval))
74101

75-
train_indices.extend(group_indices[:n_train].tolist())
76-
val_indices.extend(group_indices[n_train : n_train + n_val].tolist())
77-
test_indices.extend(group_indices[n_train + n_val :].tolist())
102+
df_train = df_trainval.iloc[train_idx_inner]
103+
df_val = df_trainval.iloc[val_idx_inner]
78104

79105
return {
80-
"train": df.loc[train_indices].reset_index(drop=True),
81-
"val": df.loc[val_indices].reset_index(drop=True),
82-
"test": df.loc[test_indices].reset_index(drop=True),
106+
"train": df_train.reset_index(drop=True),
107+
"val": df_val.reset_index(drop=True),
108+
"test": df_test.reset_index(drop=True),
83109
}

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ license = { file = "LICENSE" }
1111
requires-python = ">=3.10"
1212
dependencies = [
1313
"fastobo>=0.14",
14+
"iterative-stratification>=0.1.9",
1415
"networkx>=3.0",
1516
"numpy>=1.24",
1617
"pandas>=2.0",
1718
"rdkit>=2022.09",
19+
"scikit-learn>=1.0",
1820
"chembl_structure_pipeline>=1.2.4",
1921
]
2022

tests/fixtures/sample.obo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,4 @@ xref: wikipedia.en:Starch {source="wikipedia.en"}
8484
is_a: CHEBI:37163 ! glucan
8585
relationship: BFO:0000051 CHEBI:28057 ! has part amylopectin
8686
relationship: BFO:0000051 CHEBI:28102 ! has part amylose
87-
relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite
87+
relationship: RO:0000087 CHEBI:75771 ! has role mouse metabolite

tests/test_obo_extractor.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,28 @@ def test_returns_directed_graph(self):
1818
assert isinstance(g, nx.DiGraph)
1919

2020
def test_correct_number_of_nodes(self):
21-
# CHEBI:27189 is obsolete -> excluded; 3 explicit + 1 implicit (24921) = 4
21+
# CHEBI:27189 is obsolete -> excluded;
22+
# 4 explicit + 5 implicit (superclasses and relation targets) = 9
2223
g = build_chebi_graph(SAMPLE_OBO)
23-
assert len(g.nodes) == 4
24+
assert len(g.nodes) == 9
2425

2526
def test_node_ids_are_strings(self):
2627
g = build_chebi_graph(SAMPLE_OBO)
2728
assert all(isinstance(n, str) for n in g.nodes)
2829

2930
def test_expected_nodes_present(self):
3031
g = build_chebi_graph(SAMPLE_OBO)
31-
assert set(g.nodes) == {"10", "133004", "22750", "24921"}
32+
assert set(g.nodes) == {
33+
"10",
34+
"133004",
35+
"22750",
36+
"24921",
37+
"28017",
38+
"75771",
39+
"28057",
40+
"28102",
41+
"37163",
42+
}
3243

3344
def test_obsolete_term_excluded(self):
3445
g = build_chebi_graph(SAMPLE_OBO)
@@ -71,8 +82,8 @@ def test_isa_chain(self):
7182

7283
def test_total_edge_count(self):
7384
g = build_chebi_graph(SAMPLE_OBO)
74-
# 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a)
75-
assert len(g.edges) == 3
85+
# 10->133004 (is_a), 133004->22750 (is_a), 22750->24921 (is_a), ...
86+
assert len(g.edges) == 7
7687

7788
def test_xref_lines_do_not_break_parsing(self, tmp_path):
7889
obo_with_xrefs = tmp_path / "xref.obo"

tests/test_sdf_extractor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_chebi_id_column_present(self):
2424

2525
def test_chebi_ids_correct(self):
2626
df = extract_molecules(SAMPLE_SDF)
27-
assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"}
27+
assert set(df["chebi_id"]) == {"1", "2"}
2828

2929
def test_name_column_present(self):
3030
df = extract_molecules(SAMPLE_SDF)
@@ -57,8 +57,8 @@ def test_mol_objects_are_rdkit_mol(self):
5757

5858
def test_mol_atom_counts(self):
5959
df = extract_molecules(SAMPLE_SDF)
60-
row1 = df[df["chebi_id"] == "CHEBI:1"].iloc[0]
61-
row2 = df[df["chebi_id"] == "CHEBI:2"].iloc[0]
60+
row1 = df[df["chebi_id"] == "1"].iloc[0]
61+
row2 = df[df["chebi_id"] == "2"].iloc[0]
6262
assert row1["mol"].GetNumAtoms() == 1 # methane: 1 C
6363
assert row2["mol"].GetNumAtoms() == 2 # ethane: 2 C
6464

@@ -70,7 +70,7 @@ def test_mol_sanitized(self):
7070

7171
def test_molecule_properties(self):
7272
df = extract_molecules(SAMPLE_SDF)
73-
row = df[df["chebi_id"] == "CHEBI:1"].iloc[0]
73+
row = df[df["chebi_id"] == "1"].iloc[0]
7474
assert row["name"] == "compound A"
7575
assert row["smiles"] == "C"
7676
assert row["formula"] == "CH4"
@@ -81,7 +81,7 @@ def test_gzipped_sdf(self, tmp_path):
8181
f_out.write(f_in.read())
8282
df = extract_molecules(gz_path)
8383
assert len(df) == 2
84-
assert set(df["chebi_id"]) == {"CHEBI:1", "CHEBI:2"}
84+
assert set(df["chebi_id"]) == {"1", "2"}
8585
assert all(isinstance(m, rdchem.Mol) for m in df["mol"])
8686

8787
def test_empty_sdf_returns_empty_dataframe(self, tmp_path):
@@ -90,11 +90,11 @@ def test_empty_sdf_returns_empty_dataframe(self, tmp_path):
9090
df = extract_molecules(empty_sdf)
9191
assert df.empty
9292

93-
def test_unparseable_molblock_gives_none(self, tmp_path, recwarn):
93+
def test_unparseable_molblock_excluded(self, tmp_path, recwarn):
9494
bad_sdf = tmp_path / "bad.sdf"
9595
bad_sdf.write_text(
9696
"bad_mol\n\n 0 0 0 0 0 0 0 0 0 0999 V2000\nM END\n"
9797
"> <ChEBI ID>\nCHEBI:99\n\n$$$$\n"
9898
)
9999
df = extract_molecules(bad_sdf)
100-
assert df.iloc[0]["mol"] is None
100+
assert len(df) == 0

0 commit comments

Comments
 (0)