|
3 | 3 | from __future__ import division, print_function |
4 | 4 |
|
5 | 5 | import warnings |
6 | | - |
7 | 6 | from collections import Counter |
8 | 7 |
|
9 | 8 | import numpy as np |
10 | | - |
| 9 | +from six import string_types |
| 10 | +import sklearn |
11 | 11 | from sklearn.base import ClassifierMixin |
12 | 12 | from sklearn.ensemble import RandomForestClassifier |
13 | | -from sklearn.cross_validation import StratifiedKFold |
14 | | - |
15 | | -from six import string_types |
16 | 13 |
|
17 | 14 | from ..base import BaseBinarySampler |
18 | 15 |
|
19 | 16 |
|
def _get_cv_splits(X, y, cv, random_state):
    """Return stratified k-fold splits using whichever sklearn CV API exists.

    The ``sklearn.model_selection`` API is preferred; the legacy
    ``sklearn.cross_validation`` API is used only when the former cannot be
    imported.

    Parameters
    ----------
    X : array-like
        Samples. Only forwarded to ``split`` on the ``model_selection`` API.

    y : array-like
        Target labels used to stratify the folds.

    cv : int
        Number of folds.

    random_state : int, RandomState instance or None
        Forwarded only to the legacy ``cross_validation`` API. Folds are
        built with ``shuffle=False``, so the seed does not influence the
        splits.

    Returns
    -------
    cv_iterator : iterable
        The object produced by ``StratifiedKFold`` (legacy API) or by its
        ``split`` method (``model_selection`` API).

    """
    # Attempt the import itself instead of probing ``hasattr(sklearn, ...)``:
    # a submodule is only an attribute of its package once something has
    # imported it, so the attribute check can miss an available API.
    try:
        from sklearn.model_selection import StratifiedKFold
    except ImportError:
        from sklearn.cross_validation import StratifiedKFold
        return StratifiedKFold(
            y, n_folds=cv, shuffle=False, random_state=random_state)

    # ``random_state`` is deliberately not forwarded here: it is meaningless
    # without shuffling, and recent scikit-learn releases raise a ValueError
    # when a seed is combined with ``shuffle=False``.
    return StratifiedKFold(n_splits=cv, shuffle=False).split(X, y)
| 28 | + |
| 29 | + |
20 | 30 | class InstanceHardnessThreshold(BaseBinarySampler): |
21 | 31 | """Class to perform under-sampling based on the instance hardness |
22 | 32 | threshold. |
@@ -225,8 +235,7 @@ def _sample(self, X, y): |
225 | 235 | """ |
226 | 236 |
|
227 | 237 | # Create the different folds |
228 | | - skf = StratifiedKFold( |
229 | | - y, n_folds=self.cv, shuffle=False, random_state=self.random_state) |
| 238 | + skf = _get_cv_splits(X, y, self.cv, self.random_state) |
230 | 239 |
|
231 | 240 | probabilities = np.zeros(y.shape[0], dtype=float) |
232 | 241 |
|
|
0 commit comments