Skip to content

Commit 2126e68

Browse files
fix: typing.Generator in protocol
1 parent 9ada4e2 commit 2126e68

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

mlquantify/evaluation/protocol.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from abc import ABC, abstractmethod
22
from logging import warning
33
import numpy as np
4-
from typing import Generator, Tuple
5-
from tqdm import tqdm
64

75
from ..utils.general import *
86

@@ -62,7 +60,7 @@ def __init__(self, batch_size, random_state=None, **kwargs):
6260
raise ValueError(f"Invalid argument {name}={value}: must be int/float or list of int/float.")
6361

6462

65-
def split(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray, np.ndarray]:
63+
def split(self, X: np.ndarray, y: np.ndarray):
6664
"""
6765
Split the data into samples for evaluation.
6866
@@ -139,7 +137,7 @@ def __init__(self, batch_size, n_prevalences, repeats=1, random_state=None):
139137
n_prevalences=n_prevalences,
140138
repeats=repeats)
141139

142-
def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]:
140+
def _iter_indices(self, X: np.ndarray, y: np.ndarray):
143141

144142
n_dim = len(np.unique(y))
145143

@@ -182,7 +180,7 @@ class NPP(Protocol):
182180
... pass
183181
"""
184182

185-
def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]:
183+
def _iter_indices(self, X: np.ndarray, y: np.ndarray):
186184

187185
for batch_size in self.batch_size:
188186
yield np.random.choice(X.shape[0], batch_size, replace=True)
@@ -226,7 +224,7 @@ def __init__(self, batch_size, n_prevalences, repeats=1, random_state=None):
226224
n_prevalences=n_prevalences,
227225
repeats=repeats)
228226

229-
def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]:
227+
def _iter_indices(self, X: np.ndarray, y: np.ndarray):
230228

231229
n_dim = len(np.unique(y))
232230

@@ -279,7 +277,7 @@ def __init__(self, batch_size, prevalences, repeats=1, random_state=None):
279277
prevalences=prevalences,
280278
repeats=repeats)
281279

282-
def _iter_indices(self, X: np.ndarray, y: np.ndarray) -> Generator[np.ndarray]:
280+
def _iter_indices(self, X: np.ndarray, y: np.ndarray):
283281

284282
for batch_size in self.batch_size:
285283
for prev in self.prevalences:

0 commit comments

Comments
 (0)