|
17 | 17 | from scipy import sparse |
18 | 18 |
|
19 | 19 | from sklearn.base import clone |
20 | | -from sklearn.datasets import make_classification |
| 20 | +from sklearn.datasets import make_classification, make_multilabel_classification # noqa |
21 | 21 | from sklearn.cluster import KMeans |
22 | 22 | from sklearn.preprocessing import label_binarize |
23 | 23 | from sklearn.utils.estimator_checks import check_estimator \ |
|
27 | 27 | from sklearn.utils.testing import set_random_state |
28 | 28 | from sklearn.utils.multiclass import type_of_target |
29 | 29 |
|
| 30 | +from imblearn.base import BaseSampler |
30 | 31 | from imblearn.over_sampling.base import BaseOverSampler |
31 | 32 | from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler |
32 | 33 | from imblearn.ensemble.base import BaseEnsembleSampler |
@@ -54,10 +55,18 @@ def _yield_sampler_checks(name, Estimator): |
54 | 55 | yield check_samplers_sample_indices |
55 | 56 |
|
56 | 57 |
|
| 58 | +def _yield_classifier_checks(name, Estimator): |
| 59 | + yield check_classifier_on_multilabel_or_multioutput_targets |
| 60 | + |
| 61 | + |
57 | 62 | def _yield_all_checks(name, estimator): |
58 | 63 | # trigger our checks if this is a SamplerMixin |
59 | 64 | if hasattr(estimator, 'fit_resample'): |
60 | | - yield from _yield_sampler_checks(name, estimator) |
| 65 | + for check in _yield_sampler_checks(name, estimator): |
| 66 | + yield check |
| 67 | + if hasattr(estimator, 'predict'): |
| 68 | + for check in _yield_classifier_checks(name, estimator): |
| 69 | + yield check |
61 | 70 |
|
62 | 71 |
|
63 | 72 | def check_estimator(Estimator, run_sampler_tests=True): |
@@ -99,7 +108,8 @@ def check_target_type(name, Estimator): |
99 | 108 | # if the target is multilabel then we should raise an error |
100 | 109 | rng = np.random.RandomState(42) |
101 | 110 | y = rng.randint(2, size=(20, 3)) |
102 | | - with pytest.raises(ValueError, match="'y' should encode the multiclass"): |
| 111 | + msg = "Multilabel and multioutput targets are not supported." |
| 112 | + with pytest.raises(ValueError, match=msg): |
103 | 113 | estimator.fit_resample(X, y) |
104 | 114 |
|
105 | 115 |
|
@@ -342,3 +352,11 @@ def check_samplers_sample_indices(name, Sampler): |
342 | 352 | assert hasattr(sampler, 'sample_indices_') is sample_indices |
343 | 353 | else: |
344 | 354 | assert not hasattr(sampler, 'sample_indices_') |
| 355 | + |
| 356 | + |
| 357 | +def check_classifier_on_multilabel_or_multioutput_targets(name, Estimator): |
| 358 | + estimator = Estimator() |
| 359 | + X, y = make_multilabel_classification(n_samples=30) |
| 360 | + msg = "Multilabel and multioutput targets are not supported." |
| 361 | + with pytest.raises(ValueError, match=msg): |
| 362 | + estimator.fit(X, y) |
0 commit comments