Skip to content

Commit acda786

Browse files
authored
Merge pull request #27 from ChEB-AI/feature/lopster
Lopster integration + bug fixes
2 parents 9a79998 + 1d90c6e commit acda786

File tree

5 files changed

+34
-9
lines changed

5 files changed

+34
-9
lines changed

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,23 @@ def calculate_classwise_weights(self, predicted_classes):
3333
if model.classwise_weights is None:
3434
continue
3535
for cls, weights in model.classwise_weights.items():
36+
if cls not in predicted_classes:
37+
continue
38+
ppv = (
39+
weights["TP"] / (weights["TP"] + weights["FP"])
40+
if (weights["TP"] + weights["FP"]) > 0
41+
else 1.0
42+
)
43+
npv = (
44+
weights["TN"] / (weights["TN"] + weights["FN"])
45+
if (weights["TN"] + weights["FN"]) > 0
46+
else 1.0
47+
)
3648
positive_weights[predicted_classes[cls], j] *= (
37-
weights["PPV"] * self.weighting_strength
38-
+ (1 - self.weighting_strength)
49+
ppv * self.weighting_strength + (1 - self.weighting_strength)
3950
) ** self.weighting_exponent
4051
negative_weights[predicted_classes[cls], j] *= (
41-
weights["NPV"] * self.weighting_strength
42-
+ (1 - self.weighting_strength)
52+
npv * self.weighting_strength + (1 - self.weighting_strength)
4353
) ** self.weighting_exponent
4454

4555
if self.verbose_output:

chebifier/model_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1313
from chebifier.prediction_models.chemlog_predictor import (
1414
ChemlogAllPredictor,
15+
ChemLogLopsterClingoPredictor,
16+
ChemlogLopsterPredictor,
1517
ChemlogOrganoXCompoundPredictor,
1618
ChemlogXMolecularEntityPredictor,
1719
)
@@ -33,6 +35,8 @@
3335
"chebi_lookup": ChEBILookupPredictor,
3436
"chemlog_element": ChemlogXMolecularEntityPredictor,
3537
"chemlog_organox": ChemlogOrganoXCompoundPredictor,
38+
"lopster": ChemlogLopsterPredictor,
39+
"lopster_clingo": ChemLogLopsterClingoPredictor,
3640
"c3p": C3PPredictor,
3741
}
3842

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ def __init__(self, model_name: str, **kwargs):
107107
self.classifier = OrganoXCompoundClassifier(chebi_graph=self.chebi_graph)
108108

109109

110+
class ChemlogLopsterPredictor(ChemlogExtraPredictor):
111+
def __init__(self, model_name: str, **kwargs):
112+
from chemlog.lopster.lopster_classifier import LopsterClassifier
113+
114+
super().__init__(model_name, **kwargs)
115+
self.classifier = LopsterClassifier()
116+
117+
118+
class ChemLogLopsterClingoPredictor(ChemlogExtraPredictor):
119+
def __init__(self, model_name: str, **kwargs):
120+
from chemlog.lopster.lopster_classifier import LopsterClingoClassifier
121+
122+
super().__init__(model_name, **kwargs)
123+
self.classifier = LopsterClingoClassifier()
124+
125+
110126
class ChemlogPeptidesPredictor(BasePredictor):
111127
def __init__(self, model_name: str, **kwargs):
112128
from chemlog.cli import CLASSIFIERS

chebifier/prediction_models/electra_predictor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
5353
map_location=self.device,
5454
criterion=None,
5555
strict=False,
56-
metrics=dict(train=dict(), test=dict(), validation=dict()),
5756
pretrained_checkpoint=None,
5857
)
5958
model.eval()

chebifier/prediction_models/gnn_predictor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
6868
map_location=torch.device(self.device),
6969
criterion=None,
7070
strict=False,
71-
metrics=dict(train=dict(), test=dict(), validation=dict()),
72-
pretrained_checkpoint=None,
7371
)
7472
model.eval()
7573
return model
@@ -115,8 +113,6 @@ def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
115113
map_location=torch.device(self.device),
116114
criterion=None,
117115
strict=False,
118-
metrics=dict(train=dict(), test=dict(), validation=dict()),
119-
pretrained_checkpoint=None,
120116
)
121117
model.eval()
122118
return model

0 commit comments

Comments
 (0)