From fd1ecf3f8a67273555f35aaa61d34d9ff1224607 Mon Sep 17 00:00:00 2001 From: batukav Date: Wed, 27 Aug 2025 13:41:37 +0200 Subject: [PATCH 1/4] allow groups to be explicitly passed to train_test_group_split --- scikit_mol/splitter.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/scikit_mol/splitter.py b/scikit_mol/splitter.py index 8f5e10b..5c87524 100644 --- a/scikit_mol/splitter.py +++ b/scikit_mol/splitter.py @@ -262,7 +262,10 @@ def train_test_group_split( *arrays : sequence of indexables with same length / shape[0] Allowed inputs are lists, numpy arrays, scipy-sparse matrices or pandas dataframes. The last array must be the `groups` - array. + array, unless groups variable is set. + + groups : sequence of indexables with the same length / shape[0] + Same as *arrays. We allow users to explicitly set the groups. test_size : float or int, default=None If float, should be between 0.0 and 1.0 and represent the proportion @@ -298,16 +301,20 @@ def train_test_group_split( splitting : list, length=2 * len(arrays) List containing train-test split of inputs. """ + + if groups is not None: + arrays = arrays + (groups,) n_arrays = len(arrays) - if n_arrays < 2: - raise ValueError( - "At least two arrays are required as input (e.g., X, groups)."
- ) + if n_arrays < 3: + raise ValueError("At least two arrays are required as input (e.g., X, groups).") arrays = indexable(*arrays) - groups = arrays[-1] n_samples = _num_samples(arrays[0]) + + assert ( + len(groups) == n_samples + ), f"groups and input arrays should have the same length {len(groups)} != {n_samples}" n_train, n_test = _validate_shuffle_split( n_samples, test_size, train_size, default_test_size=0.25 ) From d6b3a9cacac2bd0cdbab744dde88a8e9582cd1a8 Mon Sep 17 00:00:00 2001 From: batukav Date: Wed, 27 Aug 2025 13:41:43 +0200 Subject: [PATCH 2/4] code formatting --- scikit_mol/splitter.py | 91 +++++++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 27 deletions(-) diff --git a/scikit_mol/splitter.py b/scikit_mol/splitter.py index 5c87524..f093a7e 100644 --- a/scikit_mol/splitter.py +++ b/scikit_mol/splitter.py @@ -23,10 +23,10 @@ def __init__( n_splits: int = 5, *, test_size: float or int = 0.2, - train_size: float or int=None, + train_size: float or int = None, random_state: int = None, sample_weighted: bool = False, - suppress_warnings: bool = False + suppress_warnings: bool = False, ): super().__init__( n_splits=n_splits, @@ -41,11 +41,16 @@ def __init__( if self.sample_weighted: warnings.warn( f"sample_weighted = True. During the test split, groups with more samples will be prioritized", - UserWarning, - ) + UserWarning, + ) + + def _iter_indices( + self, + X: Union[List, np.ndarray, pd.Series], + y: Union[List, np.ndarray, pd.Series], + groups: Union[List, np.ndarray, pd.Series], + ): - def _iter_indices(self, X: Union[List, np.ndarray, pd.Series], y: Union[List, np.ndarray, pd.Series], groups: Union[List, np.ndarray, pd.Series]): - if y is None: raise ValueError( "StratifiedGroupShuffleSplit requires 'y' for stratification." 
@@ -117,7 +122,9 @@ def _iter_indices(self, X: Union[List, np.ndarray, pd.Series], y: Union[List, np pool_size = min(5, len(safe_candidates)) candidate_pool = [cand["id"] for cand in safe_candidates[:pool_size]] if self.sample_weighted: - weights = [group_info[group_idx]["size"] for group_idx in candidate_pool] + weights = [ + group_info[group_idx]["size"] for group_idx in candidate_pool + ] best_group = rng.choice(candidate_pool, p=weights) else: best_group = rng.choice(candidate_pool) @@ -155,7 +162,9 @@ def _iter_indices(self, X: Union[List, np.ndarray, pd.Series], y: Union[List, np if valid_overshoot_candidates: # Randomly choose from the valid overshooting groups - best_overshoot_group_id = rng.choice(valid_overshoot_candidates) + best_overshoot_group_id = rng.choice( + valid_overshoot_candidates + ) test_groups.append(best_overshoot_group_id) test_indices = ( @@ -164,21 +173,27 @@ def _iter_indices(self, X: Union[List, np.ndarray, pd.Series], y: Union[List, np else [] ) if len(test_indices) == 0: - raise RuntimeError(f"Given the dataset, no train/test split could be found. Try increasing test_size") + raise RuntimeError( + f"Given the dataset, no train/test split could be found. 
Try increasing test_size" + ) all_indices = np.arange(n_samples) train_indices = np.setdiff1d(all_indices, test_indices, assume_unique=True) - + if isinstance(self.test_size, float): - + requested_test_size_ratio = self.test_size else: requested_test_size_ratio = self.test_size / n_samples - - test_size_error = np.abs(len(test_indices)/n_samples - requested_test_size_ratio) - + + test_size_error = np.abs( + len(test_indices) / n_samples - requested_test_size_ratio + ) + if not self.suppress_warnings: - if test_size_error > 0.05: # 5% deviation - warnings.warn(f"Requested and calculated test sizes differ by {test_size_error*100:.2f}%") + if test_size_error > 0.05: # 5% deviation + warnings.warn( + f"Requested and calculated test sizes differ by {test_size_error*100:.2f}%" + ) yield train_indices, test_indices @@ -220,27 +235,36 @@ def _check_split_viability(self, n_test, unique_groups, group_counts): if group_count >= n_test: n_groups += 1 too_large_groups[group_id] = group_count - if len(too_large_groups) > 0 and not self.suppress_warnings and n_groups < len(unique_groups): + if ( + len(too_large_groups) > 0 + and not self.suppress_warnings + and n_groups < len(unique_groups) + ): warnings.warn( - f''' + f""" Some groups are too large for the test set and will never be present in the test set: {too_large_groups}.\n If you want a group to be able to be present in the test set, test_size >= group_size. - ''', + """, UserWarning, ) - elif len(too_large_groups) > 0 and not self.suppress_warnings and n_groups == len(unique_groups): + elif ( + len(too_large_groups) > 0 + and not self.suppress_warnings + and n_groups == len(unique_groups) + ): warnings.warn( - ''' + """ "Warning: All available groups are larger than the target test size. The algorithm will still try to select a group that overshoots the target, which may lead to a larger than requested test set, or an completely empty test set." 
- ''', + """, UserWarning, ) - + def train_test_group_split( *arrays, + groups=None, test_size=None, train_size=None, random_state=None, @@ -343,7 +367,9 @@ def train_test_group_split( else: # stratify is None CVClass = GroupShuffleSplit - cv = CVClass(n_splits=1, test_size=n_test, train_size=n_train, random_state=random_state) + cv = CVClass( + n_splits=1, test_size=n_test, train_size=n_train, random_state=random_state + ) train, test = next(cv.split(X=arrays[0], y=y_for_split, groups=groups)) @@ -387,7 +413,16 @@ class GroupSplitCV: Whether to perform stratified sampling. If True, the `y` parameter in the `split` method is used for stratification. """ - def __init__(self, n_splits=5, *, test_size=0.2, train_size=None, random_state=None, stratify=False): + + def __init__( + self, + n_splits=5, + *, + test_size=0.2, + train_size=None, + random_state=None, + stratify=False, + ): self.n_splits = n_splits self.test_size = test_size self.train_size = train_size @@ -422,7 +457,9 @@ def split(self, X, y=None, groups=None): """ if self.stratify: if y is None: - raise ValueError("The 'y' parameter should not be None when stratify=True.") + raise ValueError( + "The 'y' parameter should not be None when stratify=True." + ) cv = StratifiedGroupShuffleSplit( n_splits=self.n_splits, test_size=self.test_size, @@ -441,4 +478,4 @@ def split(self, X, y=None, groups=None): def get_n_splits(self, X=None, y=None, groups=None): """Returns the number of splitting iterations in the cross-validator.""" - return self.n_splits \ No newline at end of file + return self.n_splits From 0f06dd3e53b66d9c345121180d2501e134871fa1 Mon Sep 17 00:00:00 2001 From: batukav Date: Wed, 27 Aug 2025 13:57:44 +0200 Subject: [PATCH 3/4] fix print_report bug causing overlapping groups. 
Rerun the downstream analysis using the subset.csv --- ...upShuffleSplit_and_MurckoTransformer.ipynb | 475 +++++++++--------- 1 file changed, 232 insertions(+), 243 deletions(-) diff --git a/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb b/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb index 1e81cb3..86674b8 100644 --- a/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb +++ b/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -312,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -414,12 +414,12 @@ " print(f\"Test groups and their counts: {groups_test_counts}\")\n", " print(f\"Train size: {len(y[train_index])}\")\n", " print(f\"Test size: {len(y[test_index])}\")\n", - " print(f\"Overlapping groups in train and test splits: {len(set(groups_balanced[train_index]).intersection(set(groups_balanced[test_index])))}\\n\")\n" + " print(f\"Overlapping groups in train and test splits: {len(set(groups[train_index]).intersection(set(groups[test_index])))}\\n\")\n" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -431,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -447,35 +447,35 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([318, 331, 351]))\n", - "Unique groups and their counts in the full dataset: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 80, 99, 122, 102, 114, 91, 105, 92, 99, 96]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([349, 331, 320]))\n", + "Unique groups and their counts in the full dataset: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 112, 99, 107]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([255, 268, 267]))\n", - "Test y and their counts: (array([0, 1, 2]), array([63, 63, 84]))\n", - "Train set class label distribution: [0.32278481 0.33924051 0.33797468]\n", - "Test set class label distribution: [0.3 0.3 0.4]\n", - "Train groups and their counts:: (array([0, 1, 2, 3, 5, 6, 7, 8]), array([ 80, 99, 122, 102, 91, 105, 92, 99]))\n", - "Test groups and their counts: (array([4, 9]), array([114, 96]))\n", - "Train size: 790\n", - "Test size: 210\n", + "Train y and their counts: (array([0, 1, 2]), array([265, 275, 256]))\n", + "Test y and their counts: (array([0, 1, 2]), array([84, 56, 64]))\n", + "Train set class label distribution: [0.33291457 0.34547739 0.32160804]\n", + "Test set class label distribution: [0.41176471 0.2745098 0.31372549]\n", + "Train groups and their counts:: (array([0, 1, 2, 3, 4, 6, 7, 9]), array([ 79, 104, 114, 105, 84, 91, 112, 107]))\n", + "Test groups and their counts: (array([5, 8]), array([105, 99]))\n", + "Train size: 796\n", + "Test size: 204\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([260, 255, 287]))\n", - "Test y and their counts: (array([0, 1, 2]), array([58, 76, 64]))\n", - "Train set class label distribution: [0.32418953 0.31795511 0.35785536]\n", - "Test set class label distribution: [0.29292929 0.38383838 0.32323232]\n", - "Train groups and 
their counts:: (array([0, 2, 3, 4, 5, 6, 7, 9]), array([ 80, 122, 102, 114, 91, 105, 92, 96]))\n", - "Test groups and their counts: (array([1, 8]), array([99, 99]))\n", - "Train size: 802\n", - "Test size: 198\n", + "Train y and their counts: (array([0, 1, 2]), array([272, 272, 247]))\n", + "Test y and their counts: (array([0, 1, 2]), array([77, 59, 73]))\n", + "Train set class label distribution: [0.34386852 0.34386852 0.31226296]\n", + "Test set class label distribution: [0.36842105 0.28229665 0.3492823 ]\n", + "Train groups and their counts:: (array([0, 2, 3, 4, 6, 7, 8, 9]), array([ 79, 114, 105, 84, 91, 112, 99, 107]))\n", + "Test groups and their counts: (array([1, 5]), array([104, 105]))\n", + "Train size: 791\n", + "Test size: 209\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -492,36 +492,36 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([932, 39, 29]))\n", - "Unique groups and their counts in the full dataset: (array([0, 1]), array([ 70, 930]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([930, 37, 33]))\n", + "Unique groups and their counts in the full dataset: (array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Test y and their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Train set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Test set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Train groups and their counts:: (array([0]), array([70]))\n", - "Test groups and their counts: (array([1]), array([930]))\n", - "Train size: 70\n", - "Test size: 930\n", - "Overlapping groups in train and test splits: 10\n", + "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Test y and 
their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Test set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Train groups and their counts:: (array([0]), array([75]))\n", + "Test groups and their counts: (array([1]), array([925]))\n", + "Train size: 75\n", + "Test size: 925\n", + "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Test y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Train set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Test set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Train groups and their counts:: (array([1]), array([930]))\n", - "Test groups and their counts: (array([0]), array([70]))\n", - "Train size: 930\n", - "Test size: 70\n", - "Overlapping groups in train and test splits: 10\n", + "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train groups and their counts:: (array([1]), array([925]))\n", + "Test groups and their counts: (array([0]), array([75]))\n", + "Train size: 925\n", + "Test size: 75\n", + "Overlapping groups in train and test splits: 0\n", "\n" ] } @@ -543,36 +543,36 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([932, 39, 29]))\n", - "Unique groups and their counts in the full dataset: (array([0, 1]), array([ 70, 930]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([930, 37, 33]))\n", + "Unique groups and 
their counts in the full dataset: (array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Test y and their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Train set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Test set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Train groups and their counts:: (array([0]), array([70]))\n", - "Test groups and their counts: (array([1]), array([930]))\n", - "Train size: 70\n", - "Test size: 930\n", - "Overlapping groups in train and test splits: 10\n", + "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Test y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Test set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Train groups and their counts:: (array([0]), array([75]))\n", + "Test groups and their counts: (array([1]), array([925]))\n", + "Train size: 75\n", + "Test size: 925\n", + "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Test y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Train set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Test set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Train groups and their counts:: (array([1]), array([930]))\n", - "Test groups and their counts: (array([0]), array([70]))\n", - "Train size: 930\n", - "Test size: 70\n", - "Overlapping groups in train and test splits: 10\n", + "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train groups and their 
counts:: (array([1]), array([925]))\n", + "Test groups and their counts: (array([0]), array([75]))\n", + "Train size: 925\n", + "Test size: 75\n", + "Overlapping groups in train and test splits: 0\n", "\n" ] } @@ -593,47 +593,47 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(array([0, 1]), array([ 70, 930]))\n", - "Full dataset class label distribution: [0.932 0.039 0.029]\n", + "(array([0, 1]), array([ 75, 925]))\n", + "Full dataset class label distribution: [0.93 0.037 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([583, 24, 20]))\n", - "Test y and their counts: (array([0, 1, 2]), array([349, 15, 9]))\n", - "Train set class label distribution: [0.92982456 0.03827751 0.03189793]\n", - "Test set class label distribution: [0.93565684 0.04021448 0.02412869]\n", - "Train groups and their counts:: (array([2, 4, 5, 6, 8, 9]), array([122, 114, 91, 105, 99, 96]))\n", - "Test groups and their counts: (array([0, 1, 3, 7]), array([ 80, 99, 102, 92]))\n", - "Train size: 627\n", - "Test size: 373\n", + "Train y and their counts: (array([0, 1, 2]), array([587, 20, 19]))\n", + "Test y and their counts: (array([0, 1, 2]), array([343, 17, 14]))\n", + "Train set class label distribution: [0.93769968 0.03194888 0.03035144]\n", + "Test set class label distribution: [0.9171123 0.04545455 0.03743316]\n", + "Train groups and their counts:: (array([1, 2, 4, 5, 7, 9]), array([104, 114, 84, 105, 112, 107]))\n", + "Test groups and their counts: (array([0, 3, 6, 8]), array([ 79, 105, 91, 99]))\n", + "Train size: 626\n", + "Test size: 374\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([657, 26, 16]))\n", - "Test y and their counts: (array([0, 1, 2]), array([275, 13, 13]))\n", - "Train set class label distribution: [0.93991416 0.03719599 0.02288984]\n", - 
"Test set class label distribution: [0.91362126 0.04318937 0.04318937]\n", - "Train groups and their counts:: (array([0, 1, 2, 3, 6, 7, 8]), array([ 80, 99, 122, 102, 105, 92, 99]))\n", - "Test groups and their counts: (array([4, 5, 9]), array([114, 91, 96]))\n", - "Train size: 699\n", - "Test size: 301\n", + "Train y and their counts: (array([0, 1, 2]), array([643, 24, 23]))\n", + "Test y and their counts: (array([0, 1, 2]), array([287, 13, 10]))\n", + "Train set class label distribution: [0.93188406 0.03478261 0.03333333]\n", + "Test set class label distribution: [0.92580645 0.04193548 0.03225806]\n", + "Train groups and their counts:: (array([0, 1, 3, 5, 6, 8, 9]), array([ 79, 104, 105, 105, 91, 99, 107]))\n", + "Test groups and their counts: (array([2, 4, 7]), array([114, 84, 112]))\n", + "Train size: 690\n", + "Test size: 310\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([624, 28, 22]))\n", - "Test y and their counts: (array([0, 1, 2]), array([308, 11, 7]))\n", - "Train set class label distribution: [0.92581602 0.04154303 0.03264095]\n", - "Test set class label distribution: [0.94478528 0.03374233 0.02147239]\n", - "Train groups and their counts:: (array([0, 1, 3, 4, 5, 7, 9]), array([ 80, 99, 102, 114, 91, 92, 96]))\n", - "Test groups and their counts: (array([2, 6, 8]), array([122, 105, 99]))\n", - "Train size: 674\n", - "Test size: 326\n", + "Train y and their counts: (array([0, 1, 2]), array([630, 30, 24]))\n", + "Test y and their counts: (array([0, 1, 2]), array([300, 7, 9]))\n", + "Train set class label distribution: [0.92105263 0.04385965 0.03508772]\n", + "Test set class label distribution: [0.94936709 0.0221519 0.02848101]\n", + "Train groups and their counts:: (array([0, 2, 3, 4, 6, 7, 8]), array([ 79, 114, 105, 84, 91, 112, 99]))\n", + "Test groups and their counts: (array([1, 5, 9]), array([104, 105, 107]))\n", + "Train size: 684\n", + "Test size: 316\n", 
"Overlapping groups in train and test splits: 0\n", "\n" ] @@ -653,45 +653,45 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(array([0, 1]), array([ 70, 930]))\n", + "(array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([223, 234, 238]))\n", - "Test y and their counts: (array([0, 1, 2]), array([ 95, 97, 113]))\n", - "Train set class label distribution: [0.32086331 0.33669065 0.34244604]\n", - "Test set class label distribution: [0.31147541 0.31803279 0.3704918 ]\n", - "Train groups and their counts:: (array([0, 1, 2, 3, 5, 6, 9]), array([ 80, 99, 122, 102, 91, 105, 96]))\n", - "Test groups and their counts: (array([4, 7, 8]), array([114, 92, 99]))\n", - "Train size: 695\n", - "Test size: 305\n", + "Train y and their counts: (array([0, 1, 2]), array([240, 242, 224]))\n", + "Test y and their counts: (array([0, 1, 2]), array([109, 89, 96]))\n", + "Train set class label distribution: [0.33994334 0.3427762 0.31728045]\n", + "Test set class label distribution: [0.3707483 0.30272109 0.32653061]\n", + "Train groups and their counts:: (array([0, 2, 3, 4, 5, 7, 9]), array([ 79, 114, 105, 84, 105, 112, 107]))\n", + "Test groups and their counts: (array([1, 6, 8]), array([104, 91, 99]))\n", + "Train size: 706\n", + "Test size: 294\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([204, 207, 217]))\n", - "Test y and their counts: (array([0, 1, 2]), array([114, 124, 134]))\n", - "Train set class label distribution: [0.32484076 0.32961783 0.3455414 ]\n", - "Test set class label distribution: [0.30645161 0.33333333 0.36021505]\n", - "Train groups and their counts:: (array([2, 4, 6, 7, 8, 9]), array([122, 114, 105, 92, 99, 96]))\n", - "Test groups and their counts: (array([0, 1, 3, 5]), array([ 80, 99, 102, 91]))\n", - "Train size: 
628\n", - "Test size: 372\n", + "Train y and their counts: (array([0, 1, 2]), array([216, 179, 193]))\n", + "Test y and their counts: (array([0, 1, 2]), array([133, 152, 127]))\n", + "Train set class label distribution: [0.36734694 0.30442177 0.32823129]\n", + "Test set class label distribution: [0.32281553 0.36893204 0.30825243]\n", + "Train groups and their counts:: (array([1, 3, 4, 5, 6, 8]), array([104, 105, 84, 105, 91, 99]))\n", + "Test groups and their counts: (array([0, 2, 7, 9]), array([ 79, 114, 112, 107]))\n", + "Train size: 588\n", + "Test size: 412\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([209, 221, 247]))\n", - "Test y and their counts: (array([0, 1, 2]), array([109, 110, 104]))\n", - "Train set class label distribution: [0.30871492 0.32644018 0.3648449 ]\n", - "Test set class label distribution: [0.3374613 0.34055728 0.32198142]\n", - "Train groups and their counts:: (array([0, 1, 3, 4, 5, 7, 8]), array([ 80, 99, 102, 114, 91, 92, 99]))\n", - "Test groups and their counts: (array([2, 6, 9]), array([122, 105, 96]))\n", - "Train size: 677\n", - "Test size: 323\n", + "Train y and their counts: (array([0, 1, 2]), array([242, 241, 223]))\n", + "Test y and their counts: (array([0, 1, 2]), array([107, 90, 97]))\n", + "Train set class label distribution: [0.3427762 0.34135977 0.31586402]\n", + "Test set class label distribution: [0.36394558 0.30612245 0.32993197]\n", + "Train groups and their counts:: (array([0, 1, 2, 6, 7, 8, 9]), array([ 79, 104, 114, 91, 112, 99, 107]))\n", + "Test groups and their counts: (array([3, 4, 5]), array([105, 84, 105]))\n", + "Train size: 706\n", + "Test size: 294\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -709,48 +709,48 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(array([0, 1]), array([ 70, 
930]))\n", + "(array([0, 1]), array([ 75, 925]))\n", "\n", - "Full dataset class label distribution: [0.932 0.039 0.029]\n", + "Full dataset class label distribution: [0.93 0.037 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([666, 27, 18]))\n", - "Test y and their counts: (array([0, 1, 2]), array([266, 12, 11]))\n", - "Train set class label distribution: [0.93670886 0.03797468 0.02531646]\n", - "Test set class label distribution: [0.92041522 0.04152249 0.03806228]\n", - "Train groups and their counts:: (array([0, 2, 3, 4, 6, 7, 9]), array([ 80, 122, 102, 114, 105, 92, 96]))\n", - "Test groups and their counts: (array([1, 5, 8]), array([99, 91, 99]))\n", - "Train size: 711\n", - "Test size: 289\n", + "Train y and their counts: (array([0, 1, 2]), array([608, 21, 18]))\n", + "Test y and their counts: (array([0, 1, 2]), array([322, 16, 15]))\n", + "Train set class label distribution: [0.93972179 0.0324575 0.02782071]\n", + "Test set class label distribution: [0.9121813 0.04532578 0.04249292]\n", + "Train groups and their counts:: (array([1, 2, 3, 5, 7, 9]), array([104, 114, 105, 105, 112, 107]))\n", + "Test groups and their counts: (array([0, 4, 6, 8]), array([79, 84, 91, 99]))\n", + "Train size: 647\n", + "Test size: 353\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([625, 26, 17]))\n", - "Test y and their counts: (array([0, 1, 2]), array([307, 13, 12]))\n", - "Train set class label distribution: [0.93562874 0.03892216 0.0254491 ]\n", - "Test set class label distribution: [0.9246988 0.03915663 0.03614458]\n", - "Train groups and their counts:: (array([0, 1, 3, 5, 6, 7, 8]), array([ 80, 99, 102, 91, 105, 92, 99]))\n", - "Test groups and their counts: (array([2, 4, 9]), array([122, 114, 96]))\n", - "Train size: 668\n", - "Test size: 332\n", + "Train y and their counts: (array([0, 1, 2]), array([628, 27, 20]))\n", + "Test y and their counts: 
(array([0, 1, 2]), array([302, 10, 13]))\n", + "Train set class label distribution: [0.93037037 0.04 0.02962963]\n", + "Test set class label distribution: [0.92923077 0.03076923 0.04 ]\n", + "Train groups and their counts:: (array([0, 3, 4, 5, 6, 7, 8]), array([ 79, 105, 84, 105, 91, 112, 99]))\n", + "Test groups and their counts: (array([1, 2, 9]), array([104, 114, 107]))\n", + "Train size: 675\n", + "Test size: 325\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([573, 25, 23]))\n", - "Test y and their counts: (array([0, 1, 2]), array([359, 14, 6]))\n", - "Train set class label distribution: [0.92270531 0.04025765 0.03703704]\n", - "Test set class label distribution: [0.94722955 0.03693931 0.01583113]\n", - "Train groups and their counts:: (array([1, 2, 4, 5, 8, 9]), array([ 99, 122, 114, 91, 99, 96]))\n", - "Test groups and their counts: (array([0, 3, 6, 7]), array([ 80, 102, 105, 92]))\n", - "Train size: 621\n", - "Test size: 379\n", + "Train y and their counts: (array([0, 1, 2]), array([624, 26, 28]))\n", + "Test y and their counts: (array([0, 1, 2]), array([306, 11, 5]))\n", + "Train set class label distribution: [0.92035398 0.03834808 0.04129794]\n", + "Test set class label distribution: [0.95031056 0.03416149 0.01552795]\n", + "Train groups and their counts:: (array([0, 1, 2, 4, 6, 8, 9]), array([ 79, 104, 114, 84, 91, 99, 107]))\n", + "Test groups and their counts: (array([3, 5, 7]), array([105, 105, 112]))\n", + "Train size: 678\n", + "Test size: 322\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -768,47 +768,47 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "(array([0, 1]), array([ 70, 930]))\n", + "(array([0, 1]), array([ 75, 925]))\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Test y and 
their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Train set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Test set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Train groups and their counts:: (array([0]), array([70]))\n", - "Test groups and their counts: (array([1]), array([930]))\n", - "Train size: 70\n", - "Test size: 930\n", - "Overlapping groups in train and test splits: 10\n", + "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train groups and their counts:: (array([1]), array([925]))\n", + "Test groups and their counts: (array([0]), array([75]))\n", + "Train size: 925\n", + "Test size: 75\n", + "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([932, 39, 29]))\n", + "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", + "Test y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", + "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Test set class label distribution: [0.92864865 0.03675676 0.03459459]\n", + "Train groups and their counts:: (array([0]), array([75]))\n", + "Test groups and their counts: (array([1]), array([925]))\n", + "Train size: 75\n", + "Test size: 925\n", + "Overlapping groups in train and test splits: 0\n", + "\n", + "Split: 2\n", + "Train y and their counts: (array([0, 1, 2]), array([930, 37, 33]))\n", "Test y and their counts: (array([], dtype=int64), array([], dtype=int64))\n", - "Train set class label distribution: [0.932 0.039 0.029]\n", + "Train set class label distribution: [0.93 0.037 0.033]\n", "Test set class label distribution: []\n", - "Train groups and their counts:: (array([0, 1]), array([ 70, 930]))\n", + 
"Train groups and their counts:: (array([0, 1]), array([ 75, 925]))\n", "Test groups and their counts: (array([], dtype=int64), array([], dtype=int64))\n", "Train size: 1000\n", "Test size: 0\n", "Overlapping groups in train and test splits: 0\n", - "\n", - "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([868, 36, 26]))\n", - "Test y and their counts: (array([0, 1, 2]), array([64, 3, 3]))\n", - "Train set class label distribution: [0.93333333 0.03870968 0.02795699]\n", - "Test set class label distribution: [0.91428571 0.04285714 0.04285714]\n", - "Train groups and their counts:: (array([1]), array([930]))\n", - "Test groups and their counts: (array([0]), array([70]))\n", - "Train size: 930\n", - "Test size: 70\n", - "Overlapping groups in train and test splits: 10\n", "\n" ] } @@ -854,47 +854,47 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Groups and their counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 80, 99, 122, 102, 114, 91, 105, 92, 99, 96]))\n", - "Full dataset class label distribution: [0.932 0.039 0.029]\n", + "Groups and their counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 112, 99, 107]))\n", + "Full dataset class label distribution: [0.93 0.037 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([768, 32, 24]))\n", - "Test y and their counts: (array([0, 1, 2]), array([164, 7, 5]))\n", - "Train set class label distribution: [0.93203883 0.03883495 0.02912621]\n", - "Test set class label distribution: [0.93181818 0.03977273 0.02840909]\n", - "Train groups and their counts:: (array([1, 2, 3, 4, 5, 6, 7, 8]), array([ 99, 122, 102, 114, 91, 105, 92, 99]))\n", - "Test groups and their counts: (array([0, 9]), array([80, 96]))\n", - "Train size: 824\n", - "Test size: 176\n", + "Train y and their counts: (array([0, 1, 2]), array([753, 30, 
26]))\n", + "Test y and their counts: (array([0, 1, 2]), array([177, 7, 7]))\n", + "Train set class label distribution: [0.93077874 0.03708282 0.03213844]\n", + "Test set class label distribution: [0.92670157 0.03664921 0.03664921]\n", + "Train groups and their counts:: (array([1, 2, 3, 4, 5, 6, 8, 9]), array([104, 114, 105, 84, 105, 91, 99, 107]))\n", + "Test groups and their counts: (array([0, 7]), array([ 79, 112]))\n", + "Train size: 809\n", + "Test size: 191\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([732, 29, 21]))\n", - "Test y and their counts: (array([0, 1, 2]), array([200, 10, 8]))\n", - "Train set class label distribution: [0.93606138 0.0370844 0.02685422]\n", - "Test set class label distribution: [0.91743119 0.04587156 0.03669725]\n", - "Train groups and their counts:: (array([0, 1, 3, 4, 5, 6, 7, 8]), array([ 80, 99, 102, 114, 91, 105, 92, 99]))\n", - "Test groups and their counts: (array([2, 9]), array([122, 96]))\n", - "Train size: 782\n", - "Test size: 218\n", + "Train y and their counts: (array([0, 1, 2]), array([735, 28, 26]))\n", + "Test y and their counts: (array([0, 1, 2]), array([195, 9, 7]))\n", + "Train set class label distribution: [0.93155894 0.03548796 0.03295311]\n", + "Test set class label distribution: [0.92417062 0.04265403 0.03317536]\n", + "Train groups and their counts:: (array([0, 1, 2, 3, 4, 5, 6, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 107]))\n", + "Test groups and their counts: (array([7, 8]), array([112, 99]))\n", + "Train size: 789\n", + "Test size: 211\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([733, 32, 22]))\n", - "Test y and their counts: (array([0, 1, 2]), array([199, 7, 7]))\n", - "Train set class label distribution: [0.93138501 0.04066074 0.02795426]\n", - "Test set class label distribution: [0.9342723 0.03286385 0.03286385]\n", - "Train 
groups and their counts:: (array([0, 1, 2, 3, 5, 6, 7, 9]), array([ 80, 99, 122, 102, 91, 105, 92, 96]))\n", - "Test groups and their counts: (array([4, 8]), array([114, 99]))\n", - "Train size: 787\n", - "Test size: 213\n", + "Train y and their counts: (array([0, 1, 2]), array([748, 29, 25]))\n", + "Test y and their counts: (array([0, 1, 2]), array([182, 8, 8]))\n", + "Train set class label distribution: [0.93266833 0.0361596 0.03117207]\n", + "Test set class label distribution: [0.91919192 0.04040404 0.04040404]\n", + "Train groups and their counts:: (array([0, 1, 3, 5, 6, 7, 8, 9]), array([ 79, 104, 105, 105, 91, 112, 99, 107]))\n", + "Test groups and their counts: (array([2, 4]), array([114, 84]))\n", + "Train size: 802\n", + "Test size: 198\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -938,30 +938,18 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Groups and their counts: (array([0, 1]), array([500, 500]))\n", - "Full dataset class label distribution: [0.932 0.039 0.029]\n", + "Groups and their counts: (array([0, 1]), array([512, 488]))\n", + "Full dataset class label distribution: [0.93 0.037 0.033]\n", "\n" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/peptid/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:232: UserWarning: \n", - " \"Warning: All available groups are larger than the target test size. \n", - " The algorithm will still try to select a group that overshoots the target, \n", - " which may lead to a larger than requested test set, or an completely empty test set.\"\n", - " \n", - " warnings.warn(\n" - ] - }, { "ename": "RuntimeError", "evalue": "Given the dataset, no train/test split could be found. 
Try increasing test_size", @@ -969,7 +957,7 @@ "traceback": [ "\u001b[31m---------------------------------------------------------------------------\u001b[39m", "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 13\u001b[39m\n\u001b[32m 11\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGroups and their counts: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgroups_counts\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFull dataset class label distribution: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp.unique(y_imbalanced,\u001b[38;5;250m \u001b[39mreturn_counts=\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[32m1\u001b[39m]/\u001b[38;5;28mlen\u001b[39m(y_imbalanced)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m13\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_index\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msgss\u001b[49m\u001b[43m.\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_imbalanced\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroups_balanced\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 14\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mcontinue\u001b[39;49;00m\n", + "\u001b[36mCell\u001b[39m\u001b[36m 
\u001b[39m\u001b[32mIn[31]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGroups and their counts: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgroups_counts\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFull dataset class label distribution: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp.unique(y_imbalanced,\u001b[38;5;250m \u001b[39mreturn_counts=\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[32m1\u001b[39m]/\u001b[38;5;28mlen\u001b[39m(y_imbalanced)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_index\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msgss\u001b[49m\u001b[43m.\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_imbalanced\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroups_balanced\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mcontinue\u001b[39;49;00m\n", "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:211\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit.split\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 185\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m 
\u001b[39m\u001b[34msplit\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y, groups=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 186\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Generates indices to split data into training and test set.\u001b[39;00m\n\u001b[32m 187\u001b[39m \n\u001b[32m 188\u001b[39m \u001b[33;03m Parameters\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 209\u001b[39m \u001b[33;03m The testing set indices for that split.\u001b[39;00m\n\u001b[32m 210\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m211\u001b[39m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m._iter_indices(X, y, groups)\n", "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:167\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit._iter_indices\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 161\u001b[39m test_indices = (\n\u001b[32m 162\u001b[39m np.concatenate([group_info[g_idx][\u001b[33m\"\u001b[39m\u001b[33mindices\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m g_idx \u001b[38;5;129;01min\u001b[39;00m test_groups])\n\u001b[32m 163\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m test_groups\n\u001b[32m 164\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m []\n\u001b[32m 165\u001b[39m )\n\u001b[32m 166\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(test_indices) == \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m167\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGiven the dataset, no train/test split could be found. 
Try increasing test_size\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 168\u001b[39m all_indices = np.arange(n_samples)\n\u001b[32m 169\u001b[39m train_indices = np.setdiff1d(all_indices, test_indices, assume_unique=\u001b[38;5;28;01mTrue\u001b[39;00m)\n", "\u001b[31mRuntimeError\u001b[39m: Given the dataset, no train/test split could be found. Try increasing test_size" @@ -984,6 +972,7 @@ "\n", "from splitter import StratifiedGroupShuffleSplit\n", "\n", + "# We should get a RuntimeError because all the groups are larger than the defined split\n", "np.set_printoptions(legacy = '1.25')\n", "sgss = StratifiedGroupShuffleSplit(n_splits=3, test_size = 0.22, random_state=43)\n", "groups_counts = np.unique(groups_balanced, return_counts=True)\n", @@ -1005,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -1032,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1066,7 +1055,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -1082,7 +1071,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -1134,7 +1123,7 @@ " SLC6A4\n", " 4061\n", " FC1=CC([C@@H]2O[C@H](CC2)CN)=C(OC)C=C1\n", - " <rdkit.Chem.rdchem.Mol object at 0x1333ecc80>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341100b0>\n", " \n", " \n", " 1\n", @@ -1149,7 +1138,7 @@ " SLC6A4\n", " 4061\n", " FC1=CC(C2OC(CC2)CN)=C(OC)C=C1\n", - " <rdkit.Chem.rdchem.Mol object at 0x1333eccf0>\n", + " <rdkit.Chem.rdchem.Mol object at 0x134110120>\n", " \n", " \n", " 2\n", @@ -1164,7 +1153,7 @@ " SLC6A4\n", " 4061\n", " FC1=CC=C(C[C@H]2C[C@@H](N(CC2)CC=C)CCCNC(=O)NC...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1333ecd60>\n", + " <rdkit.Chem.rdchem.Mol object at 0x134110190>\n", " \n", " \n", " 3\n", @@ -1179,7 +1168,7 @@ " 
SLC6A4\n", " 4061\n", " C=1C=C(C=CC1)C2=CC(=C(N2CC(C)C)C)C(NCCCN3CCN(C...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1333ecdd0>\n", + " <rdkit.Chem.rdchem.Mol object at 0x134110200>\n", " \n", " \n", " 4\n", @@ -1194,7 +1183,7 @@ " SLC6A4\n", " 4061\n", " C1=CC=C2C(=C1)C=C(C(N(C3CCNCC3)C4CCC4)=O)C=C2\n", - " <rdkit.Chem.rdchem.Mol object at 0x1333ece40>\n", + " <rdkit.Chem.rdchem.Mol object at 0x134110270>\n", " \n", " \n", " ...\n", @@ -1224,7 +1213,7 @@ " SLC6A4\n", " 4061\n", " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@@H](CC4...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1334c80b0>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341df450>\n", " \n", " \n", " 7224\n", @@ -1239,7 +1228,7 @@ " SLC6A4\n", " 4061\n", " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1334c8120>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341df4c0>\n", " \n", " \n", " 7225\n", @@ -1254,7 +1243,7 @@ " SLC6A4\n", " 4061\n", " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1334c8190>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341df530>\n", " \n", " \n", " 7226\n", @@ -1269,7 +1258,7 @@ " SLC6A4\n", " 4061\n", " FC1=CC=C(C[C@H]2C[C@@H](N(CC2)C(=O)C)CCCNC(=O)...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1334c8200>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341df5a0>\n", " \n", " \n", " 7227\n", @@ -1284,7 +1273,7 @@ " SLC6A4\n", " 4061\n", " C1CCCCC1(C2=CC=C(C(=C2)Cl)Cl)CN(CC)C\n", - " <rdkit.Chem.rdchem.Mol object at 0x1334c8270>\n", + " <rdkit.Chem.rdchem.Mol object at 0x1341df610>\n", " \n", " \n", "\n", @@ -1332,22 +1321,22 @@ "7227 4061 C1CCCCC1(C2=CC=C(C(=C2)Cl)Cl)CN(CC)C \n", "\n", " ROMol \n", - "0 \n", - "1 \n", - "2 \n", - "3 \n", - "4 \n", + "0 \n", + "1 \n", + "2 \n", + "3 \n", + "4 \n", "... ... 
\n", - "7223 \n", - "7224 \n", - "7225 \n", - "7226 \n", - "7227 \n", + "7223 \n", + "7224 \n", + "7225 \n", + "7226 \n", + "7227 \n", "\n", "[7228 rows x 12 columns]" ] }, - "execution_count": 28, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -1419,12 +1408,12 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from splitter import train_test_group_split\n", - "x_train, x_test, y_train, y_test, groups_train, groups_test = train_test_group_split(data.ROMol, data.pXC50, data.scaffold_ID, stratify=True)" + "x_train, x_test, y_train, y_test, groups_train, groups_test = train_test_group_split(data.ROMol, data.pXC50, groups = data.scaffold_ID, stratify=True)" ] }, { From 106b42ab6c15368cd57c91ebd2235dd18b53a7cf Mon Sep 17 00:00:00 2001 From: batukav Date: Wed, 27 Aug 2025 14:02:08 +0200 Subject: [PATCH 4/4] update notebook --- ...upShuffleSplit_and_MurckoTransformer.ipynb | 764 ++++++++---------- 1 file changed, 350 insertions(+), 414 deletions(-) diff --git a/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb b/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb index 86674b8..51bab74 100644 --- a/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb +++ b/scikit_mol/notebooks/StratifiedGroupShuffleSplit_and_MurckoTransformer.ipynb @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -312,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -419,7 +419,7 
@@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -431,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -447,35 +447,35 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([349, 331, 320]))\n", - "Unique groups and their counts in the full dataset: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 112, 99, 107]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([344, 334, 322]))\n", + "Unique groups and their counts in the full dataset: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([114, 96, 77, 113, 91, 112, 98, 88, 103, 108]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([265, 275, 256]))\n", - "Test y and their counts: (array([0, 1, 2]), array([84, 56, 64]))\n", - "Train set class label distribution: [0.33291457 0.34547739 0.32160804]\n", - "Test set class label distribution: [0.41176471 0.2745098 0.31372549]\n", - "Train groups and their counts:: (array([0, 1, 2, 3, 4, 6, 7, 9]), array([ 79, 104, 114, 105, 84, 91, 112, 107]))\n", - "Test groups and their counts: (array([5, 8]), array([105, 99]))\n", - "Train size: 796\n", - "Test size: 204\n", + "Train y and their counts: (array([0, 1, 2]), array([277, 275, 261]))\n", + "Test y and their counts: (array([0, 1, 2]), array([67, 59, 61]))\n", + "Train set class label distribution: [0.34071341 0.33825338 0.32103321]\n", + "Test set class label distribution: [0.35828877 0.31550802 0.32620321]\n", + "Train groups and their counts:: (array([0, 2, 3, 5, 6, 7, 8, 9]), array([114, 77, 113, 112, 98, 88, 103, 108]))\n", + "Test groups and their counts: (array([1, 4]), array([96, 91]))\n", + "Train 
size: 813\n", + "Test size: 187\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([272, 272, 247]))\n", - "Test y and their counts: (array([0, 1, 2]), array([77, 59, 73]))\n", - "Train set class label distribution: [0.34386852 0.34386852 0.31226296]\n", - "Test set class label distribution: [0.36842105 0.28229665 0.3492823 ]\n", - "Train groups and their counts:: (array([0, 2, 3, 4, 6, 7, 8, 9]), array([ 79, 114, 105, 84, 91, 112, 99, 107]))\n", - "Test groups and their counts: (array([1, 5]), array([104, 105]))\n", - "Train size: 791\n", - "Test size: 209\n", + "Train y and their counts: (array([0, 1, 2]), array([276, 264, 259]))\n", + "Test y and their counts: (array([0, 1, 2]), array([68, 70, 63]))\n", + "Train set class label distribution: [0.34543179 0.33041302 0.32415519]\n", + "Test set class label distribution: [0.33830846 0.34825871 0.31343284]\n", + "Train groups and their counts:: (array([0, 1, 2, 3, 4, 5, 7, 9]), array([114, 96, 77, 113, 91, 112, 88, 108]))\n", + "Test groups and their counts: (array([6, 8]), array([ 98, 103]))\n", + "Train size: 799\n", + "Test size: 201\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -492,20 +492,20 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([930, 37, 33]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([928, 39, 33]))\n", "Unique groups and their counts in the full dataset: (array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Test y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", - "Test set class label distribution: [0.92864865 0.03675676 
0.03459459]\n", + "Train y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Test y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Train set class label distribution: [0.93333333 0.04 0.02666667]\n", + "Test set class label distribution: [0.92756757 0.03891892 0.03351351]\n", "Train groups and their counts:: (array([0]), array([75]))\n", "Test groups and their counts: (array([1]), array([925]))\n", "Train size: 75\n", @@ -513,10 +513,10 @@ "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", - "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Test y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Train set class label distribution: [0.92756757 0.03891892 0.03351351]\n", + "Test set class label distribution: [0.93333333 0.04 0.02666667]\n", "Train groups and their counts:: (array([1]), array([925]))\n", "Test groups and their counts: (array([0]), array([75]))\n", "Train size: 925\n", @@ -543,20 +543,20 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([930, 37, 33]))\n", + "Unique y and their counts in the full dataset: (array([0, 1, 2]), array([928, 39, 33]))\n", "Unique groups and their counts in the full dataset: (array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Test y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", - "Test set class label 
distribution: [0.92864865 0.03675676 0.03459459]\n", + "Train y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Test y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Train set class label distribution: [0.93333333 0.04 0.02666667]\n", + "Test set class label distribution: [0.92756757 0.03891892 0.03351351]\n", "Train groups and their counts:: (array([0]), array([75]))\n", "Test groups and their counts: (array([1]), array([925]))\n", "Train size: 75\n", @@ -564,10 +564,10 @@ "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", - "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Test y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Train set class label distribution: [0.92756757 0.03891892 0.03351351]\n", + "Test set class label distribution: [0.93333333 0.04 0.02666667]\n", "Train groups and their counts:: (array([1]), array([925]))\n", "Test groups and their counts: (array([0]), array([75]))\n", "Train size: 925\n", @@ -593,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -601,39 +601,39 @@ "output_type": "stream", "text": [ "(array([0, 1]), array([ 75, 925]))\n", - "Full dataset class label distribution: [0.93 0.037 0.033]\n", + "Full dataset class label distribution: [0.928 0.039 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([587, 20, 19]))\n", - "Test y and their counts: (array([0, 1, 2]), array([343, 17, 14]))\n", - "Train set class label distribution: [0.93769968 0.03194888 0.03035144]\n", - "Test set class label distribution: [0.9171123 0.04545455 0.03743316]\n", - 
"Train groups and their counts:: (array([1, 2, 4, 5, 7, 9]), array([104, 114, 84, 105, 112, 107]))\n", - "Test groups and their counts: (array([0, 3, 6, 8]), array([ 79, 105, 91, 99]))\n", - "Train size: 626\n", - "Test size: 374\n", + "Train y and their counts: (array([0, 1, 2]), array([580, 31, 25]))\n", + "Test y and their counts: (array([0, 1, 2]), array([348, 8, 8]))\n", + "Train set class label distribution: [0.91194969 0.04874214 0.03930818]\n", + "Test set class label distribution: [0.95604396 0.02197802 0.02197802]\n", + "Train groups and their counts:: (array([0, 1, 3, 5, 6, 8]), array([114, 96, 113, 112, 98, 103]))\n", + "Test groups and their counts: (array([2, 4, 7, 9]), array([ 77, 91, 88, 108]))\n", + "Train size: 636\n", + "Test size: 364\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([643, 24, 23]))\n", - "Test y and their counts: (array([0, 1, 2]), array([287, 13, 10]))\n", - "Train set class label distribution: [0.93188406 0.03478261 0.03333333]\n", - "Test set class label distribution: [0.92580645 0.04193548 0.03225806]\n", - "Train groups and their counts:: (array([0, 1, 3, 5, 6, 8, 9]), array([ 79, 104, 105, 105, 91, 99, 107]))\n", - "Test groups and their counts: (array([2, 4, 7]), array([114, 84, 112]))\n", - "Train size: 690\n", - "Test size: 310\n", + "Train y and their counts: (array([0, 1, 2]), array([639, 24, 24]))\n", + "Test y and their counts: (array([0, 1, 2]), array([289, 15, 9]))\n", + "Train set class label distribution: [0.930131 0.0349345 0.0349345]\n", + "Test set class label distribution: [0.92332268 0.04792332 0.02875399]\n", + "Train groups and their counts:: (array([2, 3, 4, 5, 6, 7, 9]), array([ 77, 113, 91, 112, 98, 88, 108]))\n", + "Test groups and their counts: (array([0, 1, 8]), array([114, 96, 103]))\n", + "Train size: 687\n", + "Test size: 313\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and 
their counts: (array([0, 1, 2]), array([630, 30, 24]))\n", - "Test y and their counts: (array([0, 1, 2]), array([300, 7, 9]))\n", - "Train set class label distribution: [0.92105263 0.04385965 0.03508772]\n", - "Test set class label distribution: [0.94936709 0.0221519 0.02848101]\n", - "Train groups and their counts:: (array([0, 2, 3, 4, 6, 7, 8]), array([ 79, 114, 105, 84, 91, 112, 99]))\n", - "Test groups and their counts: (array([1, 5, 9]), array([104, 105, 107]))\n", - "Train size: 684\n", - "Test size: 316\n", + "Train y and their counts: (array([0, 1, 2]), array([637, 23, 17]))\n", + "Test y and their counts: (array([0, 1, 2]), array([291, 16, 16]))\n", + "Train set class label distribution: [0.94091581 0.03397341 0.02511078]\n", + "Test set class label distribution: [0.90092879 0.0495356 0.0495356 ]\n", + "Train groups and their counts:: (array([0, 1, 2, 4, 7, 8, 9]), array([114, 96, 77, 91, 88, 103, 108]))\n", + "Test groups and their counts: (array([3, 5, 6]), array([113, 112, 98]))\n", + "Train size: 677\n", + "Test size: 323\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -653,7 +653,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -662,36 +662,36 @@ "text": [ "(array([0, 1]), array([ 75, 925]))\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([240, 242, 224]))\n", - "Test y and their counts: (array([0, 1, 2]), array([109, 89, 96]))\n", - "Train set class label distribution: [0.33994334 0.3427762 0.31728045]\n", - "Test set class label distribution: [0.3707483 0.30272109 0.32653061]\n", - "Train groups and their counts:: (array([0, 2, 3, 4, 5, 7, 9]), array([ 79, 114, 105, 84, 105, 112, 107]))\n", - "Test groups and their counts: (array([1, 6, 8]), array([104, 91, 99]))\n", - "Train size: 706\n", - "Test size: 294\n", + "Train y and their counts: (array([0, 1, 2]), array([234, 241, 217]))\n", + "Test y and their counts: (array([0, 1, 2]), array([110, 
93, 105]))\n", + "Train set class label distribution: [0.33815029 0.3482659 0.31358382]\n", + "Test set class label distribution: [0.35714286 0.30194805 0.34090909]\n", + "Train groups and their counts:: (array([2, 3, 4, 5, 7, 8, 9]), array([ 77, 113, 91, 112, 88, 103, 108]))\n", + "Test groups and their counts: (array([0, 1, 6]), array([114, 96, 98]))\n", + "Train size: 692\n", + "Test size: 308\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([216, 179, 193]))\n", - "Test y and their counts: (array([0, 1, 2]), array([133, 152, 127]))\n", - "Train set class label distribution: [0.36734694 0.30442177 0.32823129]\n", - "Test set class label distribution: [0.32281553 0.36893204 0.30825243]\n", - "Train groups and their counts:: (array([1, 3, 4, 5, 6, 8]), array([104, 105, 84, 105, 91, 99]))\n", - "Test groups and their counts: (array([0, 2, 7, 9]), array([ 79, 114, 112, 107]))\n", - "Train size: 588\n", - "Test size: 412\n", + "Train y and their counts: (array([0, 1, 2]), array([217, 203, 196]))\n", + "Test y and their counts: (array([0, 1, 2]), array([127, 131, 126]))\n", + "Train set class label distribution: [0.35227273 0.32954545 0.31818182]\n", + "Test set class label distribution: [0.33072917 0.34114583 0.328125 ]\n", + "Train groups and their counts:: (array([0, 1, 5, 6, 7, 9]), array([114, 96, 112, 98, 88, 108]))\n", + "Test groups and their counts: (array([2, 3, 4, 8]), array([ 77, 113, 91, 103]))\n", + "Train size: 616\n", + "Test size: 384\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([242, 241, 223]))\n", - "Test y and their counts: (array([0, 1, 2]), array([107, 90, 97]))\n", - "Train set class label distribution: [0.3427762 0.34135977 0.31586402]\n", - "Test set class label distribution: [0.36394558 0.30612245 0.32993197]\n", - "Train groups and their counts:: (array([0, 1, 2, 6, 7, 8, 9]), 
array([ 79, 104, 114, 91, 112, 99, 107]))\n", - "Test groups and their counts: (array([3, 4, 5]), array([105, 84, 105]))\n", - "Train size: 706\n", - "Test size: 294\n", + "Train y and their counts: (array([0, 1, 2]), array([237, 224, 231]))\n", + "Test y and their counts: (array([0, 1, 2]), array([107, 110, 91]))\n", + "Train set class label distribution: [0.34248555 0.32369942 0.33381503]\n", + "Test set class label distribution: [0.3474026 0.35714286 0.29545455]\n", + "Train groups and their counts:: (array([0, 1, 2, 3, 4, 6, 8]), array([114, 96, 77, 113, 91, 98, 103]))\n", + "Test groups and their counts: (array([5, 7, 9]), array([112, 88, 108]))\n", + "Train size: 692\n", + "Test size: 308\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -709,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -718,39 +718,39 @@ "text": [ "(array([0, 1]), array([ 75, 925]))\n", "\n", - "Full dataset class label distribution: [0.93 0.037 0.033]\n", + "Full dataset class label distribution: [0.928 0.039 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([608, 21, 18]))\n", - "Test y and their counts: (array([0, 1, 2]), array([322, 16, 15]))\n", - "Train set class label distribution: [0.93972179 0.0324575 0.02782071]\n", - "Test set class label distribution: [0.9121813 0.04532578 0.04249292]\n", - "Train groups and their counts:: (array([1, 2, 3, 5, 7, 9]), array([104, 114, 105, 105, 112, 107]))\n", - "Test groups and their counts: (array([0, 4, 6, 8]), array([79, 84, 91, 99]))\n", - "Train size: 647\n", - "Test size: 353\n", + "Train y and their counts: (array([0, 1, 2]), array([644, 21, 22]))\n", + "Test y and their counts: (array([0, 1, 2]), array([284, 18, 11]))\n", + "Train set class label distribution: [0.93740902 0.03056769 0.03202329]\n", + "Test set class label distribution: [0.90734824 0.05750799 0.03514377]\n", + "Train groups and their counts:: 
(array([0, 1, 2, 3, 4, 7, 9]), array([114, 96, 77, 113, 91, 88, 108]))\n", + "Test groups and their counts: (array([5, 6, 8]), array([112, 98, 103]))\n", + "Train size: 687\n", + "Test size: 313\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([628, 27, 20]))\n", - "Test y and their counts: (array([0, 1, 2]), array([302, 10, 13]))\n", - "Train set class label distribution: [0.93037037 0.04 0.02962963]\n", - "Test set class label distribution: [0.92923077 0.03076923 0.04 ]\n", - "Train groups and their counts:: (array([0, 3, 4, 5, 6, 7, 8]), array([ 79, 105, 84, 105, 91, 112, 99]))\n", - "Test groups and their counts: (array([1, 2, 9]), array([104, 114, 107]))\n", - "Train size: 675\n", - "Test size: 325\n", + "Train y and their counts: (array([0, 1, 2]), array([495, 24, 16]))\n", + "Test y and their counts: (array([0, 1, 2]), array([433, 15, 17]))\n", + "Train set class label distribution: [0.92523364 0.04485981 0.02990654]\n", + "Test set class label distribution: [0.9311828 0.03225806 0.03655914]\n", + "Train groups and their counts:: (array([0, 5, 6, 8, 9]), array([114, 112, 98, 103, 108]))\n", + "Test groups and their counts: (array([1, 2, 3, 4, 7]), array([ 96, 77, 113, 91, 88]))\n", + "Train size: 535\n", + "Test size: 465\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([624, 26, 28]))\n", - "Test y and their counts: (array([0, 1, 2]), array([306, 11, 5]))\n", - "Train set class label distribution: [0.92035398 0.03834808 0.04129794]\n", - "Test set class label distribution: [0.95031056 0.03416149 0.01552795]\n", - "Train groups and their counts:: (array([0, 1, 2, 4, 6, 8, 9]), array([ 79, 104, 114, 84, 91, 99, 107]))\n", - "Test groups and their counts: (array([3, 5, 7]), array([105, 105, 112]))\n", - "Train size: 678\n", - "Test size: 322\n", + "Train y and their counts: (array([0, 1, 2]), 
array([717, 33, 28]))\n", + "Test y and their counts: (array([0, 1, 2]), array([211, 6, 5]))\n", + "Train set class label distribution: [0.92159383 0.04241645 0.03598972]\n", + "Test set class label distribution: [0.95045045 0.02702703 0.02252252]\n", + "Train groups and their counts:: (array([1, 2, 3, 4, 5, 6, 7, 8]), array([ 96, 77, 113, 91, 112, 98, 88, 103]))\n", + "Test groups and their counts: (array([0, 9]), array([114, 108]))\n", + "Train size: 778\n", + "Test size: 222\n", "Overlapping groups in train and test splits: 0\n", "\n" ] @@ -768,7 +768,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -778,10 +778,10 @@ "(array([0, 1]), array([ 75, 925]))\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Test y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Train set class label distribution: [0.92864865 0.03675676 0.03459459]\n", - "Test set class label distribution: [0.94666667 0.04 0.01333333]\n", + "Train y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Test y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Train set class label distribution: [0.92756757 0.03891892 0.03351351]\n", + "Test set class label distribution: [0.93333333 0.04 0.02666667]\n", "Train groups and their counts:: (array([1]), array([925]))\n", "Test groups and their counts: (array([0]), array([75]))\n", "Train size: 925\n", @@ -789,26 +789,26 @@ "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([71, 3, 1]))\n", - "Test y and their counts: (array([0, 1, 2]), array([859, 34, 32]))\n", - "Train set class label distribution: [0.94666667 0.04 0.01333333]\n", - "Test set class label distribution: [0.92864865 0.03675676 0.03459459]\n", - "Train groups and their counts:: (array([0]), array([75]))\n", - "Test groups and their counts: (array([1]), 
array([925]))\n", - "Train size: 75\n", - "Test size: 925\n", - "Overlapping groups in train and test splits: 0\n", - "\n", - "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([930, 37, 33]))\n", + "Train y and their counts: (array([0, 1, 2]), array([928, 39, 33]))\n", "Test y and their counts: (array([], dtype=int64), array([], dtype=int64))\n", - "Train set class label distribution: [0.93 0.037 0.033]\n", + "Train set class label distribution: [0.928 0.039 0.033]\n", "Test set class label distribution: []\n", "Train groups and their counts:: (array([0, 1]), array([ 75, 925]))\n", "Test groups and their counts: (array([], dtype=int64), array([], dtype=int64))\n", "Train size: 1000\n", "Test size: 0\n", "Overlapping groups in train and test splits: 0\n", + "\n", + "Split: 2\n", + "Train y and their counts: (array([0, 1, 2]), array([70, 3, 2]))\n", + "Test y and their counts: (array([0, 1, 2]), array([858, 36, 31]))\n", + "Train set class label distribution: [0.93333333 0.04 0.02666667]\n", + "Test set class label distribution: [0.92756757 0.03891892 0.03351351]\n", + "Train groups and their counts:: (array([0]), array([75]))\n", + "Test groups and their counts: (array([1]), array([925]))\n", + "Train size: 75\n", + "Test size: 925\n", + "Overlapping groups in train and test splits: 0\n", "\n" ] } @@ -854,50 +854,58 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Groups and their counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 112, 99, 107]))\n", - "Full dataset class label distribution: [0.93 0.037 0.033]\n", + "Groups and their counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([114, 96, 77, 113, 91, 112, 98, 88, 103, 108]))\n", + "Full dataset class label distribution: [0.928 0.039 0.033]\n", "\n", "Split: 0\n", - "Train y and their counts: (array([0, 1, 2]), array([753, 30, 26]))\n", - 
"Test y and their counts: (array([0, 1, 2]), array([177, 7, 7]))\n", - "Train set class label distribution: [0.93077874 0.03708282 0.03213844]\n", - "Test set class label distribution: [0.92670157 0.03664921 0.03664921]\n", - "Train groups and their counts:: (array([1, 2, 3, 4, 5, 6, 8, 9]), array([104, 114, 105, 84, 105, 91, 99, 107]))\n", - "Test groups and their counts: (array([0, 7]), array([ 79, 112]))\n", - "Train size: 809\n", - "Test size: 191\n", + "Train y and their counts: (array([0, 1, 2]), array([762, 30, 28]))\n", + "Test y and their counts: (array([0, 1, 2]), array([166, 9, 5]))\n", + "Train set class label distribution: [0.92926829 0.03658537 0.03414634]\n", + "Test set class label distribution: [0.92222222 0.05 0.02777778]\n", + "Train groups and their counts:: (array([0, 1, 3, 4, 5, 6, 7, 9]), array([114, 96, 113, 91, 112, 98, 88, 108]))\n", + "Test groups and their counts: (array([2, 8]), array([ 77, 103]))\n", + "Train size: 820\n", + "Test size: 180\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 1\n", - "Train y and their counts: (array([0, 1, 2]), array([735, 28, 26]))\n", - "Test y and their counts: (array([0, 1, 2]), array([195, 9, 7]))\n", - "Train set class label distribution: [0.93155894 0.03548796 0.03295311]\n", - "Test set class label distribution: [0.92417062 0.04265403 0.03317536]\n", - "Train groups and their counts:: (array([0, 1, 2, 3, 4, 5, 6, 9]), array([ 79, 104, 114, 105, 84, 105, 91, 107]))\n", - "Test groups and their counts: (array([7, 8]), array([112, 99]))\n", - "Train size: 789\n", - "Test size: 211\n", + "Train y and their counts: (array([0, 1, 2]), array([738, 34, 26]))\n", + "Test y and their counts: (array([0, 1, 2]), array([190, 5, 7]))\n", + "Train set class label distribution: [0.92481203 0.04260652 0.03258145]\n", + "Test set class label distribution: [0.94059406 0.02475248 0.03465347]\n", + "Train groups and their counts:: (array([1, 2, 3, 4, 5, 6, 8, 9]), array([ 96, 77, 113, 91, 112, 98, 
103, 108]))\n", + "Test groups and their counts: (array([0, 7]), array([114, 88]))\n", + "Train size: 798\n", + "Test size: 202\n", "Overlapping groups in train and test splits: 0\n", "\n", "Split: 2\n", - "Train y and their counts: (array([0, 1, 2]), array([748, 29, 25]))\n", - "Test y and their counts: (array([0, 1, 2]), array([182, 8, 8]))\n", - "Train set class label distribution: [0.93266833 0.0361596 0.03117207]\n", - "Test set class label distribution: [0.91919192 0.04040404 0.04040404]\n", - "Train groups and their counts:: (array([0, 1, 3, 5, 6, 7, 8, 9]), array([ 79, 104, 105, 105, 91, 112, 99, 107]))\n", - "Test groups and their counts: (array([2, 4]), array([114, 84]))\n", - "Train size: 802\n", - "Test size: 198\n", + "Train y and their counts: (array([0, 1, 2]), array([596, 30, 22]))\n", + "Test y and their counts: (array([0, 1, 2]), array([848, 21, 20]))\n", + "Train set class label distribution: [0.91975309 0.0462963 0.03395062]\n", + "Test set class label distribution: [0.95388076 0.02362205 0.02249719]\n", + "Train groups and their counts:: (array([0, 3, 5, 6, 8, 9]), array([114, 113, 112, 98, 103, 108]))\n", + "Test groups and their counts: (array([1, 2, 4, 7]), array([ 96, 77, 364, 352]))\n", + "Train size: 648\n", + "Test size: 889\n", "Overlapping groups in train and test splits: 0\n", "\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/peptid/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:194: UserWarning: Requested and calculated test sizes differ by 66.90%\n", + " warnings.warn(\n" + ] } ], "source": [ @@ -938,18 +946,30 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Groups and their counts: (array([0, 1]), array([512, 488]))\n", - "Full dataset class label distribution: [0.93 0.037 0.033]\n", + "Groups and their counts: (array([0, 1]), array([538, 
462]))\n", + "Full dataset class label distribution: [0.928 0.039 0.033]\n", "\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/peptid/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:255: UserWarning: \n", + " \"Warning: All available groups are larger than the target test size. \n", + " The algorithm will still try to select a group that overshoots the target, \n", + " which may lead to a larger than requested test set, or an completely empty test set.\"\n", + " \n", + " warnings.warn(\n" + ] + }, { "ename": "RuntimeError", "evalue": "Given the dataset, no train/test split could be found. Try increasing test_size", @@ -957,9 +977,9 @@ "traceback": [ "\u001b[31m---------------------------------------------------------------------------\u001b[39m", "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[31]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGroups and their counts: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgroups_counts\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFull dataset class label distribution: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp.unique(y_imbalanced,\u001b[38;5;250m \u001b[39mreturn_counts=\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[32m1\u001b[39m]/\u001b[38;5;28mlen\u001b[39m(y_imbalanced)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mtest_index\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msgss\u001b[49m\u001b[43m.\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_imbalanced\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroups_balanced\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mcontinue\u001b[39;49;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:211\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit.split\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 185\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msplit\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y, groups=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 186\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Generates indices to split data into training and test set.\u001b[39;00m\n\u001b[32m 187\u001b[39m \n\u001b[32m 188\u001b[39m \u001b[33;03m Parameters\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 209\u001b[39m \u001b[33;03m The testing set indices for that split.\u001b[39;00m\n\u001b[32m 210\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m211\u001b[39m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m._iter_indices(X, y, groups)\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:167\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit._iter_indices\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 161\u001b[39m test_indices = (\n\u001b[32m 162\u001b[39m 
np.concatenate([group_info[g_idx][\u001b[33m\"\u001b[39m\u001b[33mindices\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m g_idx \u001b[38;5;129;01min\u001b[39;00m test_groups])\n\u001b[32m 163\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m test_groups\n\u001b[32m 164\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m []\n\u001b[32m 165\u001b[39m )\n\u001b[32m 166\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(test_indices) == \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m167\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGiven the dataset, no train/test split could be found. Try increasing test_size\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 168\u001b[39m all_indices = np.arange(n_samples)\n\u001b[32m 169\u001b[39m train_indices = np.setdiff1d(all_indices, test_indices, assume_unique=\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[24]\u001b[39m\u001b[32m, line 14\u001b[39m\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGroups and their counts: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgroups_counts\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 13\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFull dataset class label distribution: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp.unique(y_imbalanced,\u001b[38;5;250m \u001b[39mreturn_counts=\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[32m1\u001b[39m]/\u001b[38;5;28mlen\u001b[39m(y_imbalanced)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m \u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_index\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_index\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msgss\u001b[49m\u001b[43m.\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_imbalanced\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroups_balanced\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[32m 15\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mcontinue\u001b[39;49;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:226\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit.split\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 200\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msplit\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y, groups=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 201\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Generates indices to split data into training and test set.\u001b[39;00m\n\u001b[32m 202\u001b[39m \n\u001b[32m 203\u001b[39m \u001b[33;03m Parameters\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 224\u001b[39m \u001b[33;03m The testing set indices for that split.\u001b[39;00m\n\u001b[32m 225\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m226\u001b[39m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m._iter_indices(X, y, groups)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Local_Documents/personal/git_repos/scikit_mol_bkav/scikit_mol/notebooks/../splitter.py:176\u001b[39m, in \u001b[36mStratifiedGroupShuffleSplit._iter_indices\u001b[39m\u001b[34m(self, X, y, groups)\u001b[39m\n\u001b[32m 170\u001b[39m 
test_indices = (\n\u001b[32m 171\u001b[39m np.concatenate([group_info[g_idx][\u001b[33m\"\u001b[39m\u001b[33mindices\u001b[39m\u001b[33m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m g_idx \u001b[38;5;129;01min\u001b[39;00m test_groups])\n\u001b[32m 172\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m test_groups\n\u001b[32m 173\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m []\n\u001b[32m 174\u001b[39m )\n\u001b[32m 175\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(test_indices) == \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m176\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m 177\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mGiven the dataset, no train/test split could be found. Try increasing test_size\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 178\u001b[39m )\n\u001b[32m 179\u001b[39m all_indices = np.arange(n_samples)\n\u001b[32m 180\u001b[39m train_indices = np.setdiff1d(all_indices, test_indices, assume_unique=\u001b[38;5;28;01mTrue\u001b[39;00m)\n", "\u001b[31mRuntimeError\u001b[39m: Given the dataset, no train/test split could be found. 
Try increasing test_size" ] } @@ -994,7 +1014,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1021,11 +1041,11 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ - "full_set = True\n", + "full_set = False\n", "\n", "if full_set:\n", " csv_file = \"../../tests/data/SLC6A4_active_excape_export.csv\"\n", @@ -1055,7 +1075,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -1071,7 +1091,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -1096,94 +1116,46 @@ " \n", " \n", " Ambit_InchiKey\n", - " Original_Entry_ID\n", - " Entrez_ID\n", - " Activity_Flag\n", - " pXC50\n", - " DB\n", - " Original_Assay_ID\n", - " Tax_ID\n", - " Gene_Symbol\n", - " Ortholog_Group\n", " SMILES\n", + " pXC50\n", " ROMol\n", " \n", " \n", " \n", " \n", " 0\n", - " AZMKBJHIXZCVNL-BXKDBHETNA-N\n", - " 44590643\n", - " 6532\n", - " A\n", - " 5.68382\n", - " pubchem\n", - " 393260\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " FC1=CC([C@@H]2O[C@H](CC2)CN)=C(OC)C=C1\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341100b0>\n", + " RBCQCVSMIQCOMN-PCQZLOAONA-N\n", + " C12C([C@@H](OC(C=3C=CC(=CC3)F)C=4C=CC(=CC4)F)C...\n", + " 6.26000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a813f0>\n", " \n", " \n", " 1\n", - " AZMKBJHIXZCVNL-UHFFFAOYNA-N\n", - " 11492305\n", - " 6532\n", - " A\n", - " 5.16210\n", - " pubchem\n", - " 393258\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " FC1=CC(C2OC(CC2)CN)=C(OC)C=C1\n", - " <rdkit.Chem.rdchem.Mol object at 0x134110120>\n", + " ALZTYVXVRZIERJ-UHFFFAOYNA-N\n", + " O(C1=NC=C2C(CN(CC2=C1)C)C3=CC=C(OC)C=C3)CCCN(C...\n", + " 7.18046\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a1d540>\n", " \n", " \n", " 2\n", - " AZOHUEDNMOIDOC-GETDIYNLNA-N\n", - " 44419340\n", - " 6532\n", - " 
A\n", - " 6.66354\n", - " pubchem\n", - " 276059\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " FC1=CC=C(C[C@H]2C[C@@H](N(CC2)CC=C)CCCNC(=O)NC...\n", - " <rdkit.Chem.rdchem.Mol object at 0x134110190>\n", + " MOEMPBAHOJKXBG-MRXNPFEDNA-N\n", + " O=S(=O)(N(CC=1C=CC2=CC=CC=C2C1)[C@@H]3CCNC3)C\n", + " 7.77000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a81460>\n", " \n", " \n", " 3\n", - " AZSKJKSQZWHDOK-VJSLDGLSNA-N\n", - " CHEMBL1080745\n", - " 6532\n", - " A\n", - " 6.96000\n", - " chembl20\n", - " 617082\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " C=1C=C(C=CC1)C2=CC(=C(N2CC(C)C)C)C(NCCCN3CCN(C...\n", - " <rdkit.Chem.rdchem.Mol object at 0x134110200>\n", + " HEKGBDCRHYILPL-QWOVJGMINA-N\n", + " C1(=C2C(CCCC2O)=NC=3C1=CC=CC3)NCC=4C=CC(=CC4)Cl\n", + " 5.24000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a812a0>\n", " \n", " \n", " 4\n", - " AZTPZTRJVCAAMX-UHFFFAOYNA-N\n", - " CHEMBL578346\n", - " 6532\n", - " A\n", - " 8.00000\n", - " chembl20\n", - " 596934\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " C1=CC=C2C(=C1)C=C(C(N(C3CCNCC3)C4CCC4)=O)C=C2\n", - " <rdkit.Chem.rdchem.Mol object at 0x134110270>\n", + " SNNRWIBSGBMYRF-UKRRQHHQNA-N\n", + " C1NC[C@@H](C1)[C@H](OC=2C=CC(=NC2C)OC)CC(C)C\n", + " 9.12000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a814d0>\n", " \n", " \n", " ...\n", @@ -1191,152 +1163,91 @@ " ...\n", " ...\n", " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", - " ...\n", " \n", " \n", - " 7223\n", - " ZZHKHRXDQLQSFW-HHHXNRCGNA-N\n", - " CHEMBL282380\n", - " 6532\n", - " A\n", - " 5.74000\n", - " chembl20\n", - " 532580\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@@H](CC4...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341df450>\n", + " 195\n", + " PIKWEFAACQLYMF-UHFFFAOYNA-N\n", + " C1=CC=C2C=CC(=CC2=C1)C(N3N=NC(=N3)C=4C=CC=CC4)...\n", + " 6.60000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a86810>\n", " \n", " \n", - " 7224\n", - " 
ZZHKHRXDQLQSFW-MHZLTWQENA-N\n", - " CHEMBL28149\n", - " 25553\n", - " A\n", - " 5.67000\n", - " chembl20\n", - " 198050\n", - " 10116\n", - " SLC6A4\n", - " 4061\n", - " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341df4c0>\n", + " 196\n", + " AUZWJAMWJZUPHQ-UHFFFAOYNA-N\n", + " C(OC1=CC=C(C=C1)Cl)(C=2C=CC(=CC2)F)C3CNCCC3\n", + " 7.86000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a86880>\n", " \n", " \n", - " 7225\n", - " ZZHKHRXDQLQSFW-MHZLTWQENA-N\n", - " CHEMBL28149\n", - " 6532\n", - " A\n", - " 5.66000\n", - " chembl20\n", - " 532580\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341df530>\n", + " 197\n", + " JCEWQICHOLLRDL-WUFINQPMNA-N\n", + " O(C1=CC=2[C@@H]3N(C[C@H](C2C=C1)C4=CC=C(N5N=CC...\n", + " 8.22185\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a868f0>\n", " \n", " \n", - " 7226\n", - " ZZJGNQRWIXQQDJ-CPLJGATDNA-N\n", - " 44419306\n", - " 6532\n", - " A\n", - " 5.26241\n", - " pubchem\n", - " 276059\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " FC1=CC=C(C[C@H]2C[C@@H](N(CC2)C(=O)C)CCCNC(=O)...\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341df5a0>\n", + " 198\n", + " NGRIUVQYFBDXMT-JYAVWHMHNA-N\n", + " C1NC[C@@H]2[C@H]1[C@@]2(CCOCC)C3=CC(=C(C=C3)Cl)Cl\n", + " 9.30000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a86960>\n", " \n", " \n", - " 7227\n", - " ZZRHLICOWQQNJL-UHFFFAOYNA-N\n", - " CHEMBL1683875\n", - " 6532\n", - " A\n", - " 6.81000\n", - " chembl20\n", - " 726926\n", - " 9606\n", - " SLC6A4\n", - " 4061\n", - " C1CCCCC1(C2=CC=C(C(=C2)Cl)Cl)CN(CC)C\n", - " <rdkit.Chem.rdchem.Mol object at 0x1341df610>\n", + " 199\n", + " ZWLWOTHDIGRTNE-UHFFFAOYNA-N\n", + " C(C1=CC=NC=C1)(C2=CC=CC=C2)C3=CC=CC=C3\n", + " 5.94000\n", + " <rdkit.Chem.rdchem.Mol object at 0x139a869d0>\n", " \n", " \n", "\n", - "

7228 rows × 12 columns

\n", + "

200 rows × 4 columns

\n", "" ], "text/plain": [ - " Ambit_InchiKey Original_Entry_ID Entrez_ID Activity_Flag \\\n", - "0 AZMKBJHIXZCVNL-BXKDBHETNA-N 44590643 6532 A \n", - "1 AZMKBJHIXZCVNL-UHFFFAOYNA-N 11492305 6532 A \n", - "2 AZOHUEDNMOIDOC-GETDIYNLNA-N 44419340 6532 A \n", - "3 AZSKJKSQZWHDOK-VJSLDGLSNA-N CHEMBL1080745 6532 A \n", - "4 AZTPZTRJVCAAMX-UHFFFAOYNA-N CHEMBL578346 6532 A \n", - "... ... ... ... ... \n", - "7223 ZZHKHRXDQLQSFW-HHHXNRCGNA-N CHEMBL282380 6532 A \n", - "7224 ZZHKHRXDQLQSFW-MHZLTWQENA-N CHEMBL28149 25553 A \n", - "7225 ZZHKHRXDQLQSFW-MHZLTWQENA-N CHEMBL28149 6532 A \n", - "7226 ZZJGNQRWIXQQDJ-CPLJGATDNA-N 44419306 6532 A \n", - "7227 ZZRHLICOWQQNJL-UHFFFAOYNA-N CHEMBL1683875 6532 A \n", - "\n", - " pXC50 DB Original_Assay_ID Tax_ID Gene_Symbol \\\n", - "0 5.68382 pubchem 393260 9606 SLC6A4 \n", - "1 5.16210 pubchem 393258 9606 SLC6A4 \n", - "2 6.66354 pubchem 276059 9606 SLC6A4 \n", - "3 6.96000 chembl20 617082 9606 SLC6A4 \n", - "4 8.00000 chembl20 596934 9606 SLC6A4 \n", - "... ... ... ... ... ... \n", - "7223 5.74000 chembl20 532580 9606 SLC6A4 \n", - "7224 5.67000 chembl20 198050 10116 SLC6A4 \n", - "7225 5.66000 chembl20 532580 9606 SLC6A4 \n", - "7226 5.26241 pubchem 276059 9606 SLC6A4 \n", - "7227 6.81000 chembl20 726926 9606 SLC6A4 \n", + " Ambit_InchiKey \\\n", + "0 RBCQCVSMIQCOMN-PCQZLOAONA-N \n", + "1 ALZTYVXVRZIERJ-UHFFFAOYNA-N \n", + "2 MOEMPBAHOJKXBG-MRXNPFEDNA-N \n", + "3 HEKGBDCRHYILPL-QWOVJGMINA-N \n", + "4 SNNRWIBSGBMYRF-UKRRQHHQNA-N \n", + ".. ... \n", + "195 PIKWEFAACQLYMF-UHFFFAOYNA-N \n", + "196 AUZWJAMWJZUPHQ-UHFFFAOYNA-N \n", + "197 JCEWQICHOLLRDL-WUFINQPMNA-N \n", + "198 NGRIUVQYFBDXMT-JYAVWHMHNA-N \n", + "199 ZWLWOTHDIGRTNE-UHFFFAOYNA-N \n", "\n", - " Ortholog_Group SMILES \\\n", - "0 4061 FC1=CC([C@@H]2O[C@H](CC2)CN)=C(OC)C=C1 \n", - "1 4061 FC1=CC(C2OC(CC2)CN)=C(OC)C=C1 \n", - "2 4061 FC1=CC=C(C[C@H]2C[C@@H](N(CC2)CC=C)CCCNC(=O)NC... \n", - "3 4061 C=1C=C(C=CC1)C2=CC(=C(N2CC(C)C)C)C(NCCCN3CCN(C... 
\n", - "4 4061 C1=CC=C2C(=C1)C=C(C(N(C3CCNCC3)C4CCC4)=O)C=C2 \n", - "... ... ... \n", - "7223 4061 C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@@H](CC4... \n", - "7224 4061 C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=... \n", - "7225 4061 C=1C=CC(C(C=2C=CC=CC2)OCCN3CCN(CC3)C[C@H](CC4=... \n", - "7226 4061 FC1=CC=C(C[C@H]2C[C@@H](N(CC2)C(=O)C)CCCNC(=O)... \n", - "7227 4061 C1CCCCC1(C2=CC=C(C(=C2)Cl)Cl)CN(CC)C \n", + " SMILES pXC50 \\\n", + "0 C12C([C@@H](OC(C=3C=CC(=CC3)F)C=4C=CC(=CC4)F)C... 6.26000 \n", + "1 O(C1=NC=C2C(CN(CC2=C1)C)C3=CC=C(OC)C=C3)CCCN(C... 7.18046 \n", + "2 O=S(=O)(N(CC=1C=CC2=CC=CC=C2C1)[C@@H]3CCNC3)C 7.77000 \n", + "3 C1(=C2C(CCCC2O)=NC=3C1=CC=CC3)NCC=4C=CC(=CC4)Cl 5.24000 \n", + "4 C1NC[C@@H](C1)[C@H](OC=2C=CC(=NC2C)OC)CC(C)C 9.12000 \n", + ".. ... ... \n", + "195 C1=CC=C2C=CC(=CC2=C1)C(N3N=NC(=N3)C=4C=CC=CC4)... 6.60000 \n", + "196 C(OC1=CC=C(C=C1)Cl)(C=2C=CC(=CC2)F)C3CNCCC3 7.86000 \n", + "197 O(C1=CC=2[C@@H]3N(C[C@H](C2C=C1)C4=CC=C(N5N=CC... 8.22185 \n", + "198 C1NC[C@@H]2[C@H]1[C@@]2(CCOCC)C3=CC(=C(C=C3)Cl)Cl 9.30000 \n", + "199 C(C1=CC=NC=C1)(C2=CC=CC=C2)C3=CC=CC=C3 5.94000 \n", "\n", - " ROMol \n", - "0 \n", - "1 \n", - "2 \n", - "3 \n", - "4 \n", - "... ... \n", - "7223 \n", - "7224 \n", - "7225 \n", - "7226 \n", - "7227 \n", + " ROMol \n", + "0 \n", + "1 \n", + "2 \n", + "3 \n", + "4 \n", + ".. ... 
\n", + "195 \n", + "196 \n", + "197 \n", + "198 \n", + "199 \n", "\n", - "[7228 rows x 12 columns]" + "[200 rows x 4 columns]" ] }, - "execution_count": 35, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -1360,7 +1271,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -1370,7 +1281,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -1380,7 +1291,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -1389,7 +1300,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -1399,7 +1310,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -1408,7 +1319,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0\n", + "1 1\n", + "2 2\n", + "3 3\n", + "4 4\n", + "Name: scaffold_ID, dtype: int64" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.scaffold_ID.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1432,7 +1368,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -1478,12 +1414,12 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "from splitter import GroupSplitCV\n", - "cv_scaffold = GroupSplitCV(n_splits=5, test_size=0.2, random_state=random_state)\n" + "cv_scaffold = GroupSplitCV(n_splits=5, test_size=0.2, random_state=random_state)" ] }, { @@ -1495,14 +1431,14 @@ }, { "cell_type": "code", - 
"execution_count": 37, + "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Runtime: 236.86\n" + "Runtime: 3.21\n" ] } ], @@ -1518,16 +1454,16 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.4337796417083567" + "0.4345305510316685" ] }, - "execution_count": 38, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -1556,14 +1492,14 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Runtime: 639.33\n" + "Runtime: 14.17\n" ] } ], @@ -1586,16 +1522,16 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'morganfingerprinttransformer__fpSize': 4096, 'ridge__alpha': 8}" + "{'morganfingerprinttransformer__fpSize': 1024, 'ridge__alpha': 4}" ] }, - "execution_count": 47, + "execution_count": 56, "metadata": {}, "output_type": "execute_result" } @@ -1606,16 +1542,16 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'morganfingerprinttransformer__fpSize': 4096, 'ridge__alpha': 8}" + "{'morganfingerprinttransformer__fpSize': 2048, 'ridge__alpha': 0.1}" ] }, - "execution_count": 48, + "execution_count": 57, "metadata": {}, "output_type": "execute_result" } @@ -1636,12 +1572,12 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 58, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1678,7 +1614,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1689,19 +1625,19 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Multiple Comparison of Means - Tukey HSD, FWER=0.05 \n", - "=====================================================\n", - "group1 group2 meandiff p-adj lower upper reject\n", - "-----------------------------------------------------\n", - "random scaffold -0.3827 0.0 -0.4249 -0.3404 True\n", - "-----------------------------------------------------\n" + " Multiple Comparison of Means - Tukey HSD, FWER=0.05 \n", + "======================================================\n", + "group1 group2 meandiff p-adj lower upper reject\n", + "------------------------------------------------------\n", + "random scaffold -0.1391 0.0165 -0.2503 -0.0279 True\n", + "------------------------------------------------------\n" ] } ], @@ -1722,12 +1658,12 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 61, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -1752,7 +1688,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 62, "metadata": {}, "outputs": [ { @@ -1785,14 +1721,14 @@ " \n", " 0\n", " scaffold\n", - " 0.293215\n", - " 0.43378\n", + " 0.508007\n", + " 0.434531\n", " \n", " \n", " 1\n", " random\n", - " 0.660820\n", - " 0.43378\n", + " 0.652592\n", + " 0.416297\n", " \n", " \n", "\n", @@ -1800,11 +1736,11 @@ ], "text/plain": [ " split type validation score test score\n", - "0 scaffold 0.293215 0.43378\n", - "1 random 0.660820 0.43378" + "0 scaffold 0.508007 0.434531\n", + "1 random 0.652592 0.416297" ] }, - "execution_count": 53, + "execution_count": 62, "metadata": {}, "output_type": "execute_result" }