77import torch
88from torch import Tensor
99from torchmetrics .classification import (
10- MultilabelF1Score ,
11- MultilabelPrecision ,
12- MultilabelRecall ,
13- MultilabelAUROC ,
14- BinaryF1Score ,
1510 BinaryAUROC ,
1611 BinaryAveragePrecision ,
12+ BinaryF1Score ,
13+ BinaryRecall ,
14+ MultilabelAUROC ,
1715 MultilabelAveragePrecision ,
16+ MultilabelF1Score ,
17+ MultilabelPrecision ,
18+ MultilabelRecall ,
19+ MultilabelSpecificity ,
1820)
21+ from torchmetrics .functional import specificity
1922
2023from chebai .callbacks .epoch_metrics import BalancedAccuracy , MacroF1
2124
@@ -130,13 +133,39 @@ def metrics_classification_multilabel(
130133 f1_micro = MacroF1 (preds .shape [1 ]).to (device = device )
131134 my_auc_roc = MultilabelAUROC (preds .shape [1 ]).to (device = device )
132135 my_av_prec = MultilabelAveragePrecision (preds .shape [1 ]).to (device = device )
136+ my_macro_specificity = MultilabelSpecificity (preds .shape [1 ], average = "macro" ).to (
137+ device = device
138+ )
139+ my_micro_specificity = MultilabelSpecificity (preds .shape [1 ], average = "micro" ).to (
140+ device = device
141+ )
142+ my_macro_sensitivity = MultilabelRecall (preds .shape [1 ], average = "macro" ).to (
143+ device = device
144+ )
145+ my_micro_sensitivity = MultilabelRecall (preds .shape [1 ], average = "micro" ).to (
146+ device = device
147+ )
133148
134149 macro_f1 = my_f1_macro (preds , labels ).cpu ().numpy ()
135150 micro_f1 = f1_micro (preds , labels ).cpu ().numpy ()
136151 auc_roc = my_auc_roc (preds , labels ).cpu ().numpy ()
137152 prc_auc = my_av_prec (preds , labels ).cpu ().numpy ()
138-
139- return auc_roc , macro_f1 , micro_f1 , bal_acc , prc_auc
153+ specificity_macro = my_macro_specificity (preds , labels ).cpu ().numpy ()
154+ specificity_micro = my_micro_specificity (preds , labels ).cpu ().numpy ()
155+ sensitivity_macro = my_macro_sensitivity (preds , labels ).cpu ().numpy ()
156+ sensitivity_micro = my_micro_sensitivity (preds , labels ).cpu ().numpy ()
157+
158+ return (
159+ auc_roc ,
160+ macro_f1 ,
161+ micro_f1 ,
162+ bal_acc ,
163+ prc_auc ,
164+ sensitivity_macro ,
165+ sensitivity_micro ,
166+ specificity_macro ,
167+ specificity_micro ,
168+ )
140169
141170
142171def metrics_classification_binary (
@@ -151,12 +180,15 @@ def metrics_classification_binary(
151180 my_f1 = BinaryF1Score ().to (device = device )
152181 my_av_prec = BinaryAveragePrecision ().to (device = device )
153182 my_bal_acc = BalancedAccuracy (preds .shape [1 ]).to (device = device )
183+ my_sensitivity = BinaryRecall ().to (device = device )
154184
155185 bal_acc = my_bal_acc (preds , labels ).cpu ().numpy ()
156186 auc_roc = my_auc_roc (preds , labels ).cpu ().numpy ()
157187 # my_auc_roc.update(preds.cpu()[:, 0], labels.cpu()[:, 0])
158188 # auc_roc = my_auc_roc.compute().numpy()
159189 f1_score = my_f1 (preds , labels ).cpu ().numpy ()
160190 prc_auc = my_av_prec (preds , labels ).cpu ().numpy ()
191+ sensitivity = my_sensitivity (preds , labels ).cpu ().numpy ()
192+ specificity_result = specificity (preds , labels , task = "binary" ).cpu ().numpy ()
161193
162- return auc_roc , f1_score , bal_acc , prc_auc
194+ return auc_roc , f1_score , bal_acc , prc_auc , sensitivity , specificity_result
0 commit comments