|
| 1 | +"""Tests for chebi_utils.dataset_builder.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import networkx as nx |
| 6 | +import pandas as pd |
| 7 | +import pytest |
| 8 | +from rdkit import Chem |
| 9 | + |
| 10 | +from chebi_utils.dataset_builder import ( |
| 11 | + build_labeled_dataset, |
| 12 | +) |
| 13 | + |
| 14 | + |
def _make_mol(smiles: str) -> Chem.Mol:
    """Parse *smiles* into an RDKit molecule, failing loudly on bad input.

    ``Chem.MolFromSmiles`` reports an unparsable string by returning ``None``
    rather than raising. This suite also passes ``None`` molecules on purpose
    to check that they are excluded, so a mistyped fixture SMILES would
    otherwise shrink the dataset silently instead of surfacing as an error.

    Args:
        smiles: SMILES string to parse.

    Returns:
        The molecule produced by ``Chem.MolFromSmiles``.

    Raises:
        ValueError: If RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES in test fixture: {smiles!r}")
    return mol
| 18 | + |
| 19 | + |
@pytest.fixture
def simple_graph() -> nx.DiGraph:
    """Return a five-node toy ontology whose edges point from child to parent.

    Every edge carries ``relation="is_a"``::

        A -> B -> D
        A -> C -> D
        E -> C

    Descendant sets implied by those edges:
        D: {A, B, C, E}
        C: {A, E}
        B: {A}
        A: (none)
        E: (none)
    """
    graph = nx.DiGraph()
    # Each node stores its own id as ``name``; ``smiles`` and ``subset`` are unset.
    graph.add_nodes_from(
        (node_id, {"name": node_id, "smiles": None, "subset": None})
        for node_id in "ABCDE"
    )
    graph.add_edges_from(
        [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D"), ("E", "C")],
        relation="is_a",
    )
    return graph
| 47 | + |
| 48 | + |
@pytest.fixture
def simple_molecules() -> pd.DataFrame:
    """Return a molecule table with one row each for graph nodes A, B and E.

    Columns are ``chebi_id`` (the node id) and ``mol`` (the molecule built
    from a short SMILES string for that id).
    """
    smiles_by_id = {"A": "C", "B": "CC", "E": "CCC"}
    return pd.DataFrame(
        {
            "chebi_id": list(smiles_by_id),
            "mol": [_make_mol(smiles) for smiles in smiles_by_id.values()],
        }
    )
| 58 | + |
| 59 | + |
class TestBuildLabeledDataset:
    """Behavioural tests for ``build_labeled_dataset``.

    Every test calls the builder with a graph, a molecule table and a
    ``min_molecules`` threshold, then checks the returned ``(df, labels)`` pair.
    """

    def test_returns_dataframe_with_base_columns(self, simple_graph, simple_molecules):
        """The frame has ``chebi_id``, ``mol`` and one column per returned label."""
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        assert "chebi_id" in df.columns
        assert "mol" in df.columns
        # Without this guard an empty label list would make the loop check nothing.
        assert labels, "expected at least one label for min_molecules=2"
        for lbl in labels:
            assert lbl in df.columns

    def test_one_row_per_molecule(self, simple_graph, simple_molecules):
        """Three input molecules with valid Mol objects give three rows."""
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        assert len(df) == 3

    def test_label_columns_are_boolean(self, simple_graph, simple_molecules):
        """Every label column has ``bool`` dtype."""
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        # Without this guard an empty label list would make the loop check nothing.
        assert labels, "expected at least one label for min_molecules=2"
        for lbl in labels:
            assert df[lbl].dtype == bool

    def test_one_hot_values_correct(self, simple_graph, simple_molecules):
        """Each molecule is flagged for exactly the labels among its ancestors."""
        # Labels (min=2): {B, C, D}
        #   A -> ancestors {A, B, C, D}
        #   B -> ancestors {B, D}
        #   E -> ancestors {E, C, D}
        expected = {
            "A": {"B": True, "C": True, "D": True},
            "B": {"B": True, "C": False, "D": True},
            "E": {"B": False, "C": True, "D": True},
        }
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        for chebi_id, flags in expected.items():
            row = df[df["chebi_id"] == chebi_id].iloc[0]
            for label, flag in flags.items():
                assert row[label] == flag, f"molecule {chebi_id}, label {label}"

    def test_mol_objects_preserved(self, simple_graph, simple_molecules):
        """The ``mol`` column still holds RDKit ``Mol`` instances."""
        df, _ = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=1)
        # Without this guard an empty frame would make the type check vacuous.
        assert not df.empty
        assert all(isinstance(mol, Chem.rdchem.Mol) for mol in df["mol"])

    def test_none_mols_are_excluded(self, simple_graph):
        """Rows whose ``mol`` is ``None`` do not appear in the output."""
        mol_df = pd.DataFrame(
            {
                "chebi_id": ["A", "B"],
                "mol": [_make_mol("C"), None],
            }
        )
        df, _ = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert set(df["chebi_id"]) == {"A"}

    def test_high_threshold_returns_empty(self, simple_graph, simple_molecules):
        """A threshold that no class can reach gives an empty frame and no labels."""
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=100)
        assert df.empty
        assert labels == []

    def test_molecule_not_in_graph(self, simple_graph):
        """Molecules with chebi_ids not present in the graph are still handled."""
        mol_df = pd.DataFrame(
            {
                "chebi_id": ["Z"],
                "mol": [_make_mol("C")],
            }
        )
        df, labels = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert "Z" in labels
        assert df[["Z"]].iloc[0].tolist() == [True]

    def test_non_isa_edges_ignored(self):
        """Only is_a edges should be used for hierarchy traversal."""
        g = nx.DiGraph()
        for n in ["X", "Y", "Z"]:
            g.add_node(n, name=n, smiles=None, subset=None)
        g.add_edge("X", "Y", relation="is_a")
        g.add_edge("X", "Z", relation="has_part")

        mol_df = pd.DataFrame(
            {
                "chebi_id": ["X"],
                "mol": [_make_mol("C")],
            }
        )
        df, labels = build_labeled_dataset(g, mol_df, min_molecules=1)
        # X is_a Y, so labels should include X and Y (but NOT Z via has_part)
        assert set(labels) == {"X", "Y"}
        assert df[["X", "Y"]].iloc[0].tolist() == [True, True]

    def test_empty_molecules_dataframe(self, simple_graph):
        """An empty molecule table gives an empty frame and no labels."""
        mol_df = pd.DataFrame(columns=["chebi_id", "mol"])
        df, labels = build_labeled_dataset(simple_graph, mol_df, min_molecules=1)
        assert df.empty
        assert labels == []

    def test_returned_labels_list_sorted(self, simple_graph, simple_molecules):
        """Labels come back in sorted order."""
        _, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=2)
        assert labels == ["B", "C", "D"]

    def test_returned_labels_match_columns(self, simple_graph, simple_molecules):
        """The non-base columns of the frame equal the label list, in the same order."""
        df, labels = build_labeled_dataset(simple_graph, simple_molecules, min_molecules=1)
        label_cols = [c for c in df.columns if c not in ("chebi_id", "mol")]
        assert label_cols == labels
0 commit comments