|
1 | 1 | from abc import ABC, abstractmethod |
2 | 2 | from logging import warning |
3 | 3 | import numpy as np |
4 | | -from typing import Generator, Tuple |
5 | | -from tqdm import tqdm |
6 | 4 |
|
7 | 5 | from ..utils.general import * |
8 | 6 |
|
@@ -62,7 +60,7 @@ def __init__(self, batch_size, random_state=None, **kwargs): |
62 | 60 | raise ValueError(f"Invalid argument {name}={value}: must be int/float or list of int/float.") |
63 | 61 |
|
64 | 62 |
|
65 | | - def split(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray, np.ndarray]: |
| 63 | + def split(self, X: np.ndarray, y: np.ndarray): |
66 | 64 | """ |
67 | 65 | Split the data into samples for evaluation. |
68 | 66 |
|
@@ -139,7 +137,7 @@ def __init__(self, batch_size, n_prevalences, repeats=1, random_state=None): |
139 | 137 | n_prevalences=n_prevalences, |
140 | 138 | repeats=repeats) |
141 | 139 |
|
142 | | - def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]: |
| 140 | + def _iter_indices(self, X: np.ndarray, y: np.ndarray): |
143 | 141 |
|
144 | 142 | n_dim = len(np.unique(y)) |
145 | 143 |
|
@@ -182,7 +180,7 @@ class NPP(Protocol): |
182 | 180 | ... pass |
183 | 181 | """ |
184 | 182 |
|
185 | | - def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]: |
| 183 | + def _iter_indices(self, X: np.ndarray, y: np.ndarray): |
186 | 184 |
|
187 | 185 | for batch_size in self.batch_size: |
188 | 186 | yield np.random.choice(X.shape[0], batch_size, replace=True) |
@@ -226,7 +224,7 @@ def __init__(self, batch_size, n_prevalences, repeats=1, random_state=None): |
226 | 224 | n_prevalences=n_prevalences, |
227 | 225 | repeats=repeats) |
228 | 226 |
|
229 | | - def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]: |
| 227 | + def _iter_indices(self, X: np.ndarray, y: np.ndarray): |
230 | 228 |
|
231 | 229 | n_dim = len(np.unique(y)) |
232 | 230 |
|
@@ -279,7 +277,7 @@ def __init__(self, batch_size, prevalences, repeats=1, random_state=None): |
279 | 277 | prevalences=prevalences, |
280 | 278 | repeats=repeats) |
281 | 279 |
|
282 | | - def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]: |
| 280 | + def _iter_indices(self, X: np.ndarray, y: np.ndarray): |
283 | 281 |
|
284 | 282 | for batch_size in self.batch_size: |
285 | 283 | for prev in self.prevalences: |
|
0 commit comments