Skip to content

Commit d4cdf02

Browse files
committed
merge dev
2 parents d7c45fb + 1da984f commit d4cdf02

File tree

5 files changed

+285
-1
lines changed

5 files changed

+285
-1
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ ipython_config.py
9898
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
9999
# This is especially recommended for binary packages to ensure reproducibility, and is more
100100
# commonly ignored for libraries.
101-
#uv.lock
101+
uv.lock
102102

103103
# poetry
104104
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.

chebi_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from chebi_utils.dataset_builder import build_labeled_dataset
12
from chebi_utils.downloader import download_chebi_obo, download_chebi_sdf
23
from chebi_utils.obo_extractor import build_chebi_graph
34
from chebi_utils.sdf_extractor import extract_molecules
45
from chebi_utils.splitter import create_multilabel_splits
56

67
__all__ = [
8+
"build_labeled_dataset",
79
"download_chebi_obo",
810
"download_chebi_sdf",
911
"build_chebi_graph",

chebi_utils/dataset_builder.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Build a labeled dataset by matching molecules to ChEBI ontology classes."""
2+
3+
from __future__ import annotations
4+
5+
from collections import Counter
6+
7+
import networkx as nx
8+
import pandas as pd
9+
10+
from chebi_utils.obo_extractor import get_hierarchy_subgraph
11+
12+
13+
def _count_molecules_per_class(closure: nx.DiGraph, mol_ids: set[str]) -> dict[str, int]:
    """Count how many molecules fall under each ontology class.

    Every molecule contributes one count to itself and one count to each of
    its successors in *closure*. Because *closure* is already transitively
    closed, a node's successors are all of its ancestors, so no graph
    traversal is needed here.

    Parameters
    ----------
    closure : nx.DiGraph
        Transitive closure of the ``is_a`` hierarchy (edges point from a
        class to its ancestors).
    mol_ids : set[str]
        ChEBI IDs of molecules with valid ``Mol`` objects.

    Returns
    -------
    dict[str, int]
        Mapping from ChEBI class ID to the number of molecules that are that
        class or one of its descendants. Classes with no molecules are absent.
    """
    counts: Counter[str] = Counter()
    for mid in mol_ids:
        if mid in closure:
            # One increment per ancestor of this molecule.
            counts.update(closure.successors(mid))
        # The molecule always counts for its own class, even when its ID is
        # not a node of the hierarchy at all.
        counts[mid] += 1
    return dict(counts)
38+
39+
40+
def build_labeled_dataset(
    chebi_graph: nx.DiGraph,
    molecules: pd.DataFrame,
    min_molecules: int = 50,
) -> tuple[pd.DataFrame, list[str]]:
    """Build a labeled dataset matching molecules to ontology classes.

    Each molecule is assigned to every selected label class that it belongs to
    (directly or through a chain of ``is_a`` relationships). Only classes with
    at least *min_molecules* molecules (the class itself plus direct and
    indirect descendants) are retained as labels.

    Labels are encoded **one-hot**: the returned DataFrame contains one boolean
    column per selected label.

    Parameters
    ----------
    chebi_graph : nx.DiGraph
        Full ChEBI ontology graph from :func:`build_chebi_graph`.
    molecules : pd.DataFrame
        DataFrame from :func:`extract_molecules` containing at least
        ``chebi_id`` and ``mol`` columns. Rows whose ``mol`` is missing are
        dropped.
    min_molecules : int
        Minimum number of molecules a class must cover to be selected as a
        label (default 50).

    Returns
    -------
    tuple[pd.DataFrame, list[str]]
        A tuple of:
        - DataFrame with columns ``chebi_id``, ``mol``, and one boolean
          column per selected label, in sorted label order. Each row
          represents one molecule and the index is a fresh ``RangeIndex``.
          When no class reaches the threshold, an empty DataFrame with only
          the ``chebi_id`` and ``mol`` columns is returned.
        - Sorted list of selected label ChEBI IDs.
    """
    # Keep only molecules with a valid Mol object; renumber rows so the label
    # frame built below lines up positionally.
    mol_df = molecules[molecules["mol"].notna()].reset_index(drop=True)
    mol_ids = set(mol_df["chebi_id"])

    # Build transitive closure of hierarchy once
    hierarchy = get_hierarchy_subgraph(chebi_graph)
    closure = nx.transitive_closure_dag(hierarchy)

    # Determine label set
    counts = _count_molecules_per_class(closure, mol_ids)
    labels = {cls for cls, count in counts.items() if count >= min_molecules}
    sorted_labels = sorted(labels)

    if not labels:
        return pd.DataFrame(columns=["chebi_id", "mol"]), sorted_labels

    # For each molecule compute its ancestor set (including itself) via the
    # closure, then record membership for every selected label.
    label_rows: list[list[bool]] = []
    for cid in mol_df["chebi_id"]:
        ancestors = {cid}
        if cid in closure:
            ancestors.update(closure.successors(cid))
        label_rows.append([lbl in ancestors for lbl in sorted_labels])

    # Explicit columns/dtype: column order and boolean dtype do not depend on
    # how pandas infers them from the row data.
    label_df = pd.DataFrame(label_rows, columns=sorted_labels, dtype=bool)
    result = pd.concat([mol_df[["chebi_id", "mol"]], label_df], axis=1)

    return result, sorted_labels

chebi_utils/sdf_extractor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import pandas as pd
1010
from rdkit import Chem
1111

12+
from chebi_utils.obo_extractor import _chebi_id_to_str
13+
1214

1315
def _sanitize_molecule(mol: Chem.Mol) -> Chem.Mol:
1416
"""Sanitize molecule, mirroring the ChEBI molecule processing."""
@@ -157,5 +159,14 @@ def extract_molecules(filepath: str | Path) -> pd.DataFrame:
157159

158160
chebi_ids = df["chebi_id"].tolist() if "chebi_id" in df.columns else [None] * len(df)
159161
df["mol"] = [_parse_molblock(mb, cid) for mb, cid in zip(molblocks, chebi_ids, strict=False)]
162+
df["chebi_id"] = df["chebi_id"].apply(_chebi_id_to_str)
163+
164+
# exclude records without a valid mol, but keep the same columns for consistency
165+
df = df[df["mol"].notna()]
160166

161167
return df
168+
169+
170+
if __name__ == "__main__":
    # Ad-hoc manual check: parse the SDF at this hard-coded relative path and
    # print the first rows. The file must already exist; nothing is downloaded.
    df = extract_molecules("data/chebi.sdf.gz")
    print(df.head())

tests/test_dataset_builder.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Tests for chebi_utils.dataset_builder."""
2+
3+
from __future__ import annotations
4+
5+
import networkx as nx
6+
import pandas as pd
7+
import pytest
8+
from rdkit import Chem
9+
10+
from chebi_utils.dataset_builder import (
11+
build_labeled_dataset,
12+
)
13+
14+
15+
def _make_mol(smiles: str) -> Chem.Mol:
    """Helper to create a ``Mol`` from SMILES for test data.

    Raises
    ------
    ValueError
        If RDKit cannot parse *smiles*. ``Chem.MolFromSmiles`` signals failure
        by returning ``None``, which would otherwise slip into a fixture and
        be filtered out as an "invalid molecule" by the code under test.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES in test data: {smiles!r}")
    return mol
18+
19+
20+
@pytest.fixture
def simple_graph() -> nx.DiGraph:
    """Build a small ChEBI-like directed graph (child -> parent via is_a).

    Hierarchy::

        A ─is_a─> B ─is_a─> D
        A ─is_a─> C ─is_a─> D
        E ─is_a─> C

    Ontology descendants:
        D: {A, B, C, E}
        C: {A, E}
        B: {A}
        A: (none)
        E: (none)
    """
    g = nx.DiGraph()
    for node in ["A", "B", "C", "D", "E"]:
        g.add_node(node, name=node, smiles=None, subset=None)

    # (child, parent) pairs; every edge is an ``is_a`` relation.
    is_a_edges = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("E", "C")]
    for child, parent in is_a_edges:
        g.add_edge(child, parent, relation="is_a")
    return g
47+
48+
49+
@pytest.fixture
def simple_molecules() -> pd.DataFrame:
    """Three molecules with IDs A, B, E (matching graph nodes).

    Each row pairs a ``chebi_id`` with a small ``Mol`` built from SMILES.
    """
    return pd.DataFrame(
        {
            "chebi_id": ["A", "B", "E"],
            "mol": [_make_mol("C"), _make_mol("CC"), _make_mol("CCC")],
        }
    )
58+
59+
60+
class TestBuildLabeledDataset:
    """Tests for ``build_labeled_dataset`` on small hand-built ontologies."""

    def test_returns_dataframe_with_base_columns(self, simple_graph, simple_molecules):
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        assert "chebi_id" in df.columns
        assert "mol" in df.columns
        # Label columns should also be present
        for lbl in labels:
            assert lbl in df.columns

    def test_one_row_per_molecule(self, simple_graph, simple_molecules):
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        # 3 molecules with valid Mol -> 3 rows
        assert len(df) == 3

    def test_label_columns_are_boolean(self, simple_graph, simple_molecules):
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        for lbl in labels:
            assert df[lbl].dtype == bool

    def test_one_hot_values_correct(self, simple_graph, simple_molecules):
        # Labels (min=2): {B, C, D}
        # A -> ancestors {A,B,C,D} -> B=True,  C=True,  D=True
        # B -> ancestors {B,D}     -> B=True,  C=False, D=True
        # E -> ancestors {E,C,D}   -> B=False, C=True,  D=True
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        expected = {
            "A": {"B": True, "C": True, "D": True},
            "B": {"B": True, "C": False, "D": True},
            "E": {"B": False, "C": True, "D": True},
        }
        for chebi_id, expected_row in expected.items():
            row = df[df["chebi_id"] == chebi_id].iloc[0]
            for label, value in expected_row.items():
                # bool() normalises numpy/Python booleans before the identity check.
                assert bool(row[label]) is value, (chebi_id, label)

    def test_mol_objects_preserved(self, simple_graph, simple_molecules):
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=1)
        for _, row in df.iterrows():
            assert isinstance(row["mol"], Chem.rdchem.Mol)

    def test_none_mols_are_excluded(self, simple_graph):
        mol_df = pd.DataFrame(
            {
                "chebi_id": ["A", "B"],
                "mol": [_make_mol("C"), None],
            }
        )
        df, _ = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert set(df["chebi_id"]) == {"A"}

    def test_high_threshold_returns_empty(self, simple_graph, simple_molecules):
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=100)
        assert df.empty
        assert labels == []

    def test_molecule_not_in_graph(self, simple_graph):
        """Molecules with chebi_ids not present in the graph are still handled."""
        mol_df = pd.DataFrame(
            {
                "chebi_id": ["Z"],
                "mol": [_make_mol("C")],
            }
        )
        df, labels = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert "Z" in labels
        assert df.iloc[0]["Z"]

    def test_non_isa_edges_ignored(self):
        """Only is_a edges should be used for hierarchy traversal."""
        g = nx.DiGraph()
        for n in ["X", "Y", "Z"]:
            g.add_node(n, name=n, smiles=None, subset=None)
        g.add_edge("X", "Y", relation="is_a")
        g.add_edge("X", "Z", relation="has_part")

        mol_df = pd.DataFrame(
            {
                "chebi_id": ["X"],
                "mol": [_make_mol("C")],
            }
        )
        df, labels = build_labeled_dataset(g, mol_df, min_molecules=1)
        # X is_a Y, so labels should include X and Y (but NOT Z via has_part)
        assert set(labels) == {"X", "Y"}
        assert df.iloc[0]["X"]
        assert df.iloc[0]["Y"]

    def test_empty_molecules_dataframe(self, simple_graph):
        mol_df = pd.DataFrame(columns=["chebi_id", "mol"])
        df, labels = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert df.empty
        assert labels == []

    def test_returned_labels_list_sorted(self, simple_graph, simple_molecules):
        _, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        assert labels == ["B", "C", "D"]

    def test_returned_labels_match_columns(self, simple_graph, simple_molecules):
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=1)
        label_cols = [c for c in df.columns if c not in ("chebi_id", "mol")]
        assert label_cols == labels

0 commit comments

Comments
 (0)