|
2 | 2 | from abc import abstractmethod |
3 | 3 | from scipy.optimize import minimize |
4 | 4 | import warnings |
| 5 | +from sklearn.metrics import confusion_matrix |
5 | 6 |
|
6 | 7 | from mlquantify.adjust_counting._base import BaseAdjustCount |
7 | 8 | from mlquantify.adjust_counting._counting import CC, PCC |
@@ -208,7 +209,7 @@ def _adjust(self, predictions, train_y_pred, train_y_values): |
208 | 209 | prevs_estim = self._get_estimations(predictions > priors) |
209 | 210 | prevalence = self._solve_optimization(prevs_estim, priors) |
210 | 211 | else: |
211 | | - self.CM = self._compute_confusion_matrix(train_y_pred) |
| 212 | + self.CM = self._compute_confusion_matrix(train_y_pred, train_y_values) |
212 | 213 | prevs_estim = self._get_estimations(predictions) |
213 | 214 | prevalence = self._solve_linear(prevs_estim) |
214 | 215 |
|
@@ -389,8 +390,11 @@ class GAC(CrispLearnerQMixin, MatrixAdjustment): |
389 | 390 | def __init__(self, learner=None): |
390 | 391 | super().__init__(learner=learner, solver='linear') |
391 | 392 |
|
392 | | - def _compute_confusion_matrix(self, predictions): |
393 | | - prev_estim = self._get_estimations(predictions) |
| 393 | + def _compute_confusion_matrix(self, predictions, y_values): |
| 394 | + self.CM = confusion_matrix(y_values, predictions, labels=self.classes_).T |
| 395 | + self.CM = self.CM.astype(float) |
| 396 | + prev_estim = self.CM.sum(axis=0) |
| 397 | + |
394 | 398 | for i, _ in enumerate(self.classes_): |
395 | 399 | if prev_estim[i] == 0: |
396 | 400 | self.CM[i, i] = 1 |
@@ -448,13 +452,16 @@ class GPAC(SoftLearnerQMixin, MatrixAdjustment): |
448 | 452 | def __init__(self, learner=None): |
449 | 453 | super().__init__(learner=learner, solver='linear') |
450 | 454 |
|
451 | | - def _compute_confusion_matrix(self, posteriors): |
452 | | - prev_estim = self._get_estimations(posteriors) |
453 | | - for i, _ in enumerate(self.classes_): |
454 | | - if prev_estim[i] == 0: |
455 | | - self.CM[i, i] = 1 |
456 | | - else: |
457 | | - self.CM[:, i] /= prev_estim[i] |
| 455 | + def _compute_confusion_matrix(self, posteriors, y_values): |
| 456 | + n_classes = len(self.classes_) |
| 457 | + confusion = np.eye(n_classes) |
| 458 | + |
| 459 | + for i, class_label in enumerate(self.classes_): |
| 460 | + indices = (y_values == class_label) |
| 461 | + if np.any(indices): |
| 462 | + confusion[i] = posteriors[indices].mean(axis=0) |
| 463 | + |
| 464 | + self.CM = confusion.T |
458 | 465 | return self.CM |
459 | 466 |
|
460 | 467 |
|
|
0 commit comments