Skip to content

Commit 23edd4c

Browse files
fix matrix adjustment methods with error in linear solving
1 parent 3800a0c commit 23edd4c

1 file changed

Lines changed: 17 additions & 10 deletions

File tree

mlquantify/adjust_counting/_adjustment.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import abstractmethod
33
from scipy.optimize import minimize
44
import warnings
5+
from sklearn.metrics import confusion_matrix
56

67
from mlquantify.adjust_counting._base import BaseAdjustCount
78
from mlquantify.adjust_counting._counting import CC, PCC
@@ -208,7 +209,7 @@ def _adjust(self, predictions, train_y_pred, train_y_values):
208209
prevs_estim = self._get_estimations(predictions > priors)
209210
prevalence = self._solve_optimization(prevs_estim, priors)
210211
else:
211-
self.CM = self._compute_confusion_matrix(train_y_pred)
212+
self.CM = self._compute_confusion_matrix(train_y_pred, train_y_values)
212213
prevs_estim = self._get_estimations(predictions)
213214
prevalence = self._solve_linear(prevs_estim)
214215

@@ -389,8 +390,11 @@ class GAC(CrispLearnerQMixin, MatrixAdjustment):
389390
def __init__(self, learner=None):
390391
super().__init__(learner=learner, solver='linear')
391392

392-
def _compute_confusion_matrix(self, predictions):
393-
prev_estim = self._get_estimations(predictions)
393+
def _compute_confusion_matrix(self, predictions, y_values):
394+
self.CM = confusion_matrix(y_values, predictions, labels=self.classes_).T
395+
self.CM = self.CM.astype(float)
396+
prev_estim = self.CM.sum(axis=0)
397+
394398
for i, _ in enumerate(self.classes_):
395399
if prev_estim[i] == 0:
396400
self.CM[i, i] = 1
@@ -448,13 +452,16 @@ class GPAC(SoftLearnerQMixin, MatrixAdjustment):
448452
def __init__(self, learner=None):
449453
super().__init__(learner=learner, solver='linear')
450454

451-
def _compute_confusion_matrix(self, posteriors):
452-
prev_estim = self._get_estimations(posteriors)
453-
for i, _ in enumerate(self.classes_):
454-
if prev_estim[i] == 0:
455-
self.CM[i, i] = 1
456-
else:
457-
self.CM[:, i] /= prev_estim[i]
455+
def _compute_confusion_matrix(self, posteriors, y_values):
456+
n_classes = len(self.classes_)
457+
confusion = np.eye(n_classes)
458+
459+
for i, class_label in enumerate(self.classes_):
460+
indices = (y_values == class_label)
461+
if np.any(indices):
462+
confusion[i] = posteriors[indices].mean(axis=0)
463+
464+
self.CM = confusion.T
458465
return self.CM
459466

460467

0 commit comments

Comments
 (0)