|
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 |
|
| 13 | +from pytest import approx, raises |
| 14 | + |
13 | 15 | from sklearn import datasets |
14 | 16 | from sklearn import svm |
15 | 17 |
|
16 | 18 | from sklearn.preprocessing import label_binarize |
17 | 19 | from sklearn.utils.fixes import np_version |
18 | 20 | from sklearn.utils.validation import check_random_state |
19 | 21 | from sklearn.utils.testing import assert_allclose, assert_array_equal |
20 | | -from sklearn.utils.testing import assert_no_warnings, assert_raises |
21 | | -from sklearn.utils.testing import assert_warns_message, ignore_warnings |
22 | | -from sklearn.utils.testing import assert_raise_message |
| 22 | +from sklearn.utils.testing import assert_no_warnings |
| 23 | +from sklearn.utils.testing import ignore_warnings |
23 | 24 | from sklearn.metrics import accuracy_score, average_precision_score |
24 | 25 | from sklearn.metrics import brier_score_loss, cohen_kappa_score |
25 | 26 | from sklearn.metrics import jaccard_similarity_score, precision_score |
|
32 | 33 | from imblearn.metrics import make_index_balanced_accuracy |
33 | 34 | from imblearn.metrics import classification_report_imbalanced |
34 | 35 |
|
35 | | -from pytest import approx |
| 36 | +from imblearn.utils.testing import warns |
| 37 | + |
36 | 38 |
|
37 | 39 | RND_SEED = 42 |
38 | 40 | R_TOL = 1e-2 |
@@ -177,40 +179,30 @@ def test_sensitivity_specificity_error_multilabels(): |
177 | 179 | y_true_bin = label_binarize(y_true, classes=np.arange(5)) |
178 | 180 | y_pred_bin = label_binarize(y_pred, classes=np.arange(5)) |
179 | 181 |
|
180 | | - assert_raises(ValueError, sensitivity_score, y_true_bin, y_pred_bin) |
| 182 | + with raises(ValueError): |
| 183 | + sensitivity_score(y_true_bin, y_pred_bin) |
181 | 184 |
|
182 | 185 |
|
183 | 186 | @ignore_warnings |
184 | 187 | def test_sensitivity_specificity_support_errors(): |
185 | 188 | y_true, y_pred, _ = make_prediction(binary=True) |
186 | 189 |
|
187 | 190 | # Bad pos_label |
188 | | - assert_raises( |
189 | | - ValueError, |
190 | | - sensitivity_specificity_support, |
191 | | - y_true, |
192 | | - y_pred, |
193 | | - pos_label=2, |
194 | | - average='binary') |
| 191 | + with raises(ValueError): |
| 192 | + sensitivity_specificity_support(y_true, y_pred, pos_label=2, |
| 193 | + average='binary') |
195 | 194 |
|
196 | 195 | # Bad average option |
197 | | - assert_raises( |
198 | | - ValueError, |
199 | | - sensitivity_specificity_support, [0, 1, 2], [1, 2, 0], |
200 | | - average='mega') |
| 196 | + with raises(ValueError): |
| 197 | + sensitivity_specificity_support([0, 1, 2], [1, 2, 0], average='mega') |
201 | 198 |
|
202 | 199 |
|
203 | 200 | def test_sensitivity_specificity_unused_pos_label(): |
204 | 201 | # but average != 'binary'; even if data is binary |
205 | | - assert_warns_message( |
206 | | - UserWarning, |
207 | | - "Note that pos_label (set to 2) is " |
208 | | - "ignored when average != 'binary' (got 'macro'). You " |
209 | | - "may use labels=[pos_label] to specify a single " |
210 | | - "positive class.", |
211 | | - sensitivity_specificity_support, [1, 2, 1], [1, 2, 2], |
212 | | - pos_label=2, |
213 | | - average='macro') |
| 202 | +    with warns(UserWarning, r"use labels=\[pos_label\] to specify a single"):
| 203 | + sensitivity_specificity_support([1, 2, 1], [1, 2, 2], |
| 204 | + pos_label=2, |
| 205 | + average='macro') |
214 | 206 |
|
215 | 207 |
|
216 | 208 | def test_geometric_mean_support_binary(): |
@@ -405,10 +397,8 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label(): |
405 | 397 | u'0.15 0.44 0.19 31 red\xa2 0.42 0.90 0.55 0.57 0.63 ' |
406 | 398 | u'0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75') |
407 | 399 | if np_version[:3] < (1, 7, 0): |
408 | | - expected_message = ("NumPy < 1.7.0 does not implement" |
409 | | - " searchsorted on unicode data correctly.") |
410 | | - assert_raise_message(RuntimeError, expected_message, |
411 | | - classification_report_imbalanced, y_true, y_pred) |
| 400 | + with raises(RuntimeError, match="NumPy < 1.7.0"): |
| 401 | + classification_report_imbalanced(y_true, y_pred) |
412 | 402 | else: |
413 | 403 | report = classification_report_imbalanced(y_true, y_pred) |
414 | 404 | assert _format_report(report) == expected_report |
@@ -459,16 +449,20 @@ def test_iba_error_y_score_prob(): |
459 | 449 |
|
460 | 450 | aps = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
461 | 451 | average_precision_score) |
462 | | - assert_raises(AttributeError, aps, y_true, y_pred) |
| 452 | + with raises(AttributeError): |
| 453 | + aps(y_true, y_pred) |
463 | 454 |
|
464 | 455 | brier = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
465 | 456 | brier_score_loss) |
466 | | - assert_raises(AttributeError, brier, y_true, y_pred) |
| 457 | + with raises(AttributeError): |
| 458 | + brier(y_true, y_pred) |
467 | 459 |
|
468 | 460 | kappa = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
469 | 461 | cohen_kappa_score) |
470 | | - assert_raises(AttributeError, kappa, y_true, y_pred) |
| 462 | + with raises(AttributeError): |
| 463 | + kappa(y_true, y_pred) |
471 | 464 |
|
472 | 465 | ras = make_index_balanced_accuracy(alpha=0.5, squared=True)( |
473 | 466 | roc_auc_score) |
474 | | - assert_raises(AttributeError, ras, y_true, y_pred) |
| 467 | + with raises(AttributeError): |
| 468 | + ras(y_true, y_pred) |
0 commit comments