|
5 | 5 | import pandas as pd |
6 | 6 | import pytest |
7 | 7 |
|
8 | | -from chebi_utils.splitter import create_splits |
| 8 | +from chebi_utils.splitter import create_multilabel_splits, create_splits |
9 | 9 |
|
10 | 10 |
|
11 | 11 | @pytest.fixture |
@@ -103,3 +103,89 @@ def test_stratified_reproducible(self, sample_df): |
103 | 103 | splits1 = create_splits(sample_df, stratify_col="category", seed=42) |
104 | 104 | splits2 = create_splits(sample_df, stratify_col="category", seed=42) |
105 | 105 | pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) |
| 106 | + |
| 107 | + |
| 108 | +@pytest.fixture |
| 109 | +def multilabel_df(): |
| 110 | + """DataFrame with multilabel 'labels' column (200 rows).""" |
| 111 | + all_labels = [["A"], ["B"], ["C"], ["A", "B"], ["A", "C"], ["B", "C"]] |
| 112 | + labels = [all_labels[i % len(all_labels)] for i in range(200)] |
| 113 | + return pd.DataFrame( |
| 114 | + { |
| 115 | + "id": [f"CHEBI:{i}" for i in range(200)], |
| 116 | + "labels": labels, |
| 117 | + } |
| 118 | + ) |
| 119 | + |
| 120 | + |
| 121 | +@pytest.fixture |
| 122 | +def singlelabel_df(): |
| 123 | + """DataFrame with single-label 'labels' column.""" |
| 124 | + return pd.DataFrame( |
| 125 | + { |
| 126 | + "id": [f"CHEBI:{i}" for i in range(200)], |
| 127 | + "labels": [["A"] if i % 2 == 0 else ["B"] for i in range(200)], |
| 128 | + } |
| 129 | + ) |
| 130 | + |
| 131 | + |
| 132 | +class TestCreateMultilabelSplits: |
| 133 | + def test_returns_three_splits(self, multilabel_df): |
| 134 | + splits = create_multilabel_splits(multilabel_df, labels_col="labels") |
| 135 | + assert set(splits.keys()) == {"train", "val", "test"} |
| 136 | + |
| 137 | + def test_sizes_sum_to_total(self, multilabel_df): |
| 138 | + splits = create_multilabel_splits(multilabel_df, labels_col="labels") |
| 139 | + assert sum(len(v) for v in splits.values()) == len(multilabel_df) |
| 140 | + |
| 141 | + def test_no_overlap(self, multilabel_df): |
| 142 | + splits = create_multilabel_splits(multilabel_df, labels_col="labels") |
| 143 | + train_ids = set(splits["train"]["id"]) |
| 144 | + val_ids = set(splits["val"]["id"]) |
| 145 | + test_ids = set(splits["test"]["id"]) |
| 146 | + assert train_ids.isdisjoint(val_ids) |
| 147 | + assert train_ids.isdisjoint(test_ids) |
| 148 | + assert val_ids.isdisjoint(test_ids) |
| 149 | + |
| 150 | + def test_all_rows_covered(self, multilabel_df): |
| 151 | + splits = create_multilabel_splits(multilabel_df, labels_col="labels") |
| 152 | + all_ids = set(splits["train"]["id"]) | set(splits["val"]["id"]) | set(splits["test"]["id"]) |
| 153 | + assert all_ids == set(multilabel_df["id"]) |
| 154 | + |
| 155 | + def test_reproducible_with_same_seed(self, multilabel_df): |
| 156 | + splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) |
| 157 | + splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=7) |
| 158 | + pd.testing.assert_frame_equal(splits1["train"], splits2["train"]) |
| 159 | + |
| 160 | + def test_different_seeds_give_different_splits(self, multilabel_df): |
| 161 | + splits1 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=1) |
| 162 | + splits2 = create_multilabel_splits(multilabel_df, labels_col="labels", seed=2) |
| 163 | + assert not splits1["train"]["id"].equals(splits2["train"]["id"]) |
| 164 | + |
| 165 | + def test_approximate_split_sizes(self, multilabel_df): |
| 166 | + splits = create_multilabel_splits( |
| 167 | + multilabel_df, labels_col="labels", train_ratio=0.8, val_ratio=0.1, test_ratio=0.1 |
| 168 | + ) |
| 169 | + n = len(multilabel_df) |
| 170 | + assert abs(len(splits["test"]) - int(n * 0.1)) <= 2 |
| 171 | + assert abs(len(splits["val"]) - int(n * 0.1)) <= 2 |
| 172 | + |
| 173 | + def test_invalid_ratios_raise_error(self, multilabel_df): |
| 174 | + with pytest.raises(ValueError, match="must equal 1.0"): |
| 175 | + create_multilabel_splits( |
| 176 | + multilabel_df, labels_col="labels", train_ratio=0.5, val_ratio=0.3, test_ratio=0.3 |
| 177 | + ) |
| 178 | + |
| 179 | + def test_missing_labels_col_raises_error(self, multilabel_df): |
| 180 | + with pytest.raises(ValueError, match="not found in DataFrame"): |
| 181 | + create_multilabel_splits(multilabel_df, labels_col="nonexistent") |
| 182 | + |
| 183 | + def test_singlelabel_path(self, singlelabel_df): |
| 184 | + """Single-label lists should use StratifiedShuffleSplit without error.""" |
| 185 | + splits = create_multilabel_splits(singlelabel_df, labels_col="labels") |
| 186 | + assert sum(len(v) for v in splits.values()) == len(singlelabel_df) |
| 187 | + train_ids = set(splits["train"]["id"]) |
| 188 | + val_ids = set(splits["val"]["id"]) |
| 189 | + test_ids = set(splits["test"]["id"]) |
| 190 | + assert train_ids.isdisjoint(val_ids) |
| 191 | + assert train_ids.isdisjoint(test_ids) |
0 commit comments