Skip to content

Commit dc82d38

Browse files
authored
Merge pull request #149 from schnamo/dev
Adding more metrics to evaluation
2 parents 8428e04 + ce459ef commit dc82d38

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4373,3 +4373,5 @@ b
43734373
[CaH2]
43744374
[NH3]
43754375
[OH2]
4376+
[TlH2+]
4377+
[SbH6+3]

chebai/result/classification.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@
77
import torch
88
from torch import Tensor
99
from torchmetrics.classification import (
10-
MultilabelF1Score,
11-
MultilabelPrecision,
12-
MultilabelRecall,
13-
MultilabelAUROC,
14-
BinaryF1Score,
1510
BinaryAUROC,
1611
BinaryAveragePrecision,
12+
BinaryF1Score,
13+
BinaryRecall,
14+
MultilabelAUROC,
1715
MultilabelAveragePrecision,
16+
MultilabelF1Score,
17+
MultilabelPrecision,
18+
MultilabelRecall,
19+
MultilabelSpecificity,
1820
)
21+
from torchmetrics.functional import specificity
1922

2023
from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1
2124

@@ -130,13 +133,39 @@ def metrics_classification_multilabel(
130133
f1_micro = MacroF1(preds.shape[1]).to(device=device)
131134
my_auc_roc = MultilabelAUROC(preds.shape[1]).to(device=device)
132135
my_av_prec = MultilabelAveragePrecision(preds.shape[1]).to(device=device)
136+
my_macro_specificity = MultilabelSpecificity(preds.shape[1], average="macro").to(
137+
device=device
138+
)
139+
my_micro_specificity = MultilabelSpecificity(preds.shape[1], average="micro").to(
140+
device=device
141+
)
142+
my_macro_sensitivity = MultilabelRecall(preds.shape[1], average="macro").to(
143+
device=device
144+
)
145+
my_micro_sensitivity = MultilabelRecall(preds.shape[1], average="micro").to(
146+
device=device
147+
)
133148

134149
macro_f1 = my_f1_macro(preds, labels).cpu().numpy()
135150
micro_f1 = f1_micro(preds, labels).cpu().numpy()
136151
auc_roc = my_auc_roc(preds, labels).cpu().numpy()
137152
prc_auc = my_av_prec(preds, labels).cpu().numpy()
138-
139-
return auc_roc, macro_f1, micro_f1, bal_acc, prc_auc
153+
specificity_macro = my_macro_specificity(preds, labels).cpu().numpy()
154+
specificity_micro = my_micro_specificity(preds, labels).cpu().numpy()
155+
sensitivity_macro = my_macro_sensitivity(preds, labels).cpu().numpy()
156+
sensitivity_micro = my_micro_sensitivity(preds, labels).cpu().numpy()
157+
158+
return (
159+
auc_roc,
160+
macro_f1,
161+
micro_f1,
162+
bal_acc,
163+
prc_auc,
164+
sensitivity_macro,
165+
sensitivity_micro,
166+
specificity_macro,
167+
specificity_micro,
168+
)
140169

141170

142171
def metrics_classification_binary(
@@ -151,12 +180,15 @@ def metrics_classification_binary(
151180
my_f1 = BinaryF1Score().to(device=device)
152181
my_av_prec = BinaryAveragePrecision().to(device=device)
153182
my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device)
183+
my_sensitivity = BinaryRecall().to(device=device)
154184

155185
bal_acc = my_bal_acc(preds, labels).cpu().numpy()
156186
auc_roc = my_auc_roc(preds, labels).cpu().numpy()
157187
# my_auc_roc.update(preds.cpu()[:, 0], labels.cpu()[:, 0])
158188
# auc_roc = my_auc_roc.compute().numpy()
159189
f1_score = my_f1(preds, labels).cpu().numpy()
160190
prc_auc = my_av_prec(preds, labels).cpu().numpy()
191+
sensitivity = my_sensitivity(preds, labels).cpu().numpy()
192+
specificity_result = specificity(preds, labels, task="binary").cpu().numpy()
161193

162-
return auc_roc, f1_score, bal_acc, prc_auc
194+
return auc_roc, f1_score, bal_acc, prc_auc, sensitivity, specificity_result

0 commit comments

Comments
 (0)