Skip to content

Commit a4365a6

Browse files
authored
Merge pull request #25 from ChEB-AI/fix/new_pred_pipeline
Utilize new prediction pipeline
2 parents 3474ec6 + e10e129 commit a4365a6

File tree

10 files changed

+53
-306
lines changed

10 files changed

+53
-306
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ def predict_smiles_list(
271271
"resgated_0ps1g189": {
272272
"type": "resgated",
273273
"ckpt_path": "data/0ps1g189/epoch=122.ckpt",
274-
"target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
275274
"molecular_properties": [
276275
"chebai_graph.preprocessing.properties.AtomType",
277276
"chebai_graph.preprocessing.properties.NumAtomBonds",
@@ -289,7 +288,6 @@ def predict_smiles_list(
289288
"electra_14ko0zcf": {
290289
"type": "electra",
291290
"ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
292-
"target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
293291
# "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
294292
},
295293
}

chebifier/hugging_face.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def download_model_files(
2525
model_config (Dict[str, str | Dict[str, str]]): A dictionary containing:
2626
- 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname').
2727
- 'subfolder' (str): The subfolder within the repo where the files are located.
28-
- 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt_path', 'target_labels_path') to
28+
- 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt_path') to
2929
actual file names (e.g., 'electra.ckpt', 'classes.txt').
3030
3131
Returns:

chebifier/model_registry.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ChEBILookupPredictor,
88
ChemlogPeptidesPredictor,
99
ElectraPredictor,
10-
ResGatedPredictor,
10+
GNNPredictor,
1111
)
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1313
from chebifier.prediction_models.chemlog_predictor import (
@@ -17,7 +17,6 @@
1717
ChemlogOrganoXCompoundPredictor,
1818
ChemlogXMolecularEntityPredictor,
1919
)
20-
from chebifier.prediction_models.gnn_predictor import GATPredictor
2120

2221
ENSEMBLES = {
2322
"mv": BaseEnsemble,
@@ -28,8 +27,8 @@
2827

2928
MODEL_TYPES = {
3029
"electra": ElectraPredictor,
31-
"resgated": ResGatedPredictor,
32-
"gat": GATPredictor,
30+
"resgated": GNNPredictor,
31+
"gat": GNNPredictor,
3332
"chemlog": ChemlogAllPredictor,
3433
"chemlog_peptides": ChemlogPeptidesPredictor,
3534
"chebi_lookup": ChEBILookupPredictor,

chebifier/model_registry.yml

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,115 +4,42 @@ electra_chebi50-3star_v244:
44
repo_id: chebai/electra_chebi50-3star_v244
55
files:
66
ckpt_path: electra_chebi50-3star_v244_x2mngani_epoch=180.ckpt
7-
target_labels_path: classes.txt
87
classwise_weights_path: electra_chebi50-3star_v244_x2mngani_epoch=180_trust_3star.json
98
gat_chebi50_v244:
109
type: gat
1110
hugging_face:
1211
repo_id: chebai/gat_chebi50_v244
1312
files:
1413
ckpt_path: gat_chebi50_v244_0nfi19qt_epoch=198.ckpt
15-
target_labels_path: classes.txt
1614
classwise_weights_path: gat_chebi50_v244_0nfi19qt_epoch=198_trust_3star.json
17-
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties
18-
molecular_properties:
19-
- chebai_graph.preprocessing.properties.AtomType
20-
- chebai_graph.preprocessing.properties.NumAtomBonds
21-
- chebai_graph.preprocessing.properties.AtomCharge
22-
- chebai_graph.preprocessing.properties.AtomAromaticity
23-
- chebai_graph.preprocessing.properties.AtomHybridization
24-
- chebai_graph.preprocessing.properties.AtomNumHs
25-
- chebai_graph.preprocessing.properties.BondType
26-
- chebai_graph.preprocessing.properties.BondInRing
27-
- chebai_graph.preprocessing.properties.BondAromaticity
28-
- chebai_graph.preprocessing.properties.RDKit2DNormalized
2915
gat-aug_chebi50_v244:
3016
type: gat
3117
hugging_face:
3218
repo_id: chebai/gat-aug_chebi50_v244
3319
files:
3420
ckpt_path: gat-aug_chebi50_v244_8fky8tru_epoch=192.ckpt
35-
target_labels_path: classes.txt
3621
classwise_weights_path: gat-aug_chebi50_v244_8fky8tru_epoch=192_trust_3star.json
37-
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
38-
molecular_properties:
39-
- chebai_graph.preprocessing.properties.AtomNodeLevel
40-
# Atom Node type properties
41-
- chebai_graph.preprocessing.properties.AugAtomAromaticity
42-
- chebai_graph.preprocessing.properties.AugAtomCharge
43-
- chebai_graph.preprocessing.properties.AugAtomHybridization
44-
- chebai_graph.preprocessing.properties.AugAtomNumHs
45-
- chebai_graph.preprocessing.properties.AugAtomType
46-
- chebai_graph.preprocessing.properties.AugNumAtomBonds
47-
# FG Node type properties
48-
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
49-
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
50-
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
51-
- chebai_graph.preprocessing.properties.IsFGAlkyl
52-
# Graph Node type properties
53-
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
54-
# Bond properties
55-
- chebai_graph.preprocessing.properties.BondLevel
56-
- chebai_graph.preprocessing.properties.AugBondAromaticity
57-
- chebai_graph.preprocessing.properties.AugBondInRing
58-
- chebai_graph.preprocessing.properties.AugBondType
5922
resgated-aug_chebi50-3star_v244:
6023
type: resgated
6124
hugging_face:
6225
repo_id: chebai/resgated-aug_chebi50-3star_v244
6326
files:
6427
ckpt_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190.ckpt
65-
target_labels_path: classes.txt
6628
classwise_weights_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190_trust_3star.json
67-
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
68-
molecular_properties:
69-
- chebai_graph.preprocessing.properties.AtomNodeLevel
70-
# Atom Node type properties
71-
- chebai_graph.preprocessing.properties.AugAtomAromaticity
72-
- chebai_graph.preprocessing.properties.AugAtomCharge
73-
- chebai_graph.preprocessing.properties.AugAtomHybridization
74-
- chebai_graph.preprocessing.properties.AugAtomNumHs
75-
- chebai_graph.preprocessing.properties.AugAtomType
76-
- chebai_graph.preprocessing.properties.AugNumAtomBonds
77-
# FG Node type properties
78-
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
79-
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
80-
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
81-
- chebai_graph.preprocessing.properties.IsFGAlkyl
82-
# Graph Node type properties
83-
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
84-
# Bond properties
85-
- chebai_graph.preprocessing.properties.BondLevel
86-
- chebai_graph.preprocessing.properties.AugBondAromaticity
87-
- chebai_graph.preprocessing.properties.AugBondInRing
88-
- chebai_graph.preprocessing.properties.AugBondType
8929
electra_chebi50_v241:
9030
type: electra
9131
hugging_face:
9232
repo_id: chebai/electra_chebi50_v241
9333
files:
9434
ckpt_path: 14ko0zcf_epoch=193.ckpt
95-
target_labels_path: classes.txt
9635
classwise_weights_path: metrics_electra_14ko0zcf_80-10-10_short.json
9736
resgated_chebi50_v241:
9837
type: resgated
9938
hugging_face:
10039
repo_id: chebai/resgated_gcn_chebi50_v241
10140
files:
10241
ckpt_path: 0ps1g189_epoch=122.ckpt
103-
target_labels_path: classes.txt
10442
classwise_weights_path: metrics_0ps1g189_80-10-10_short.json
105-
molecular_properties:
106-
- chebai_graph.preprocessing.properties.AtomType
107-
- chebai_graph.preprocessing.properties.NumAtomBonds
108-
- chebai_graph.preprocessing.properties.AtomCharge
109-
- chebai_graph.preprocessing.properties.AtomAromaticity
110-
- chebai_graph.preprocessing.properties.AtomHybridization
111-
- chebai_graph.preprocessing.properties.AtomNumHs
112-
- chebai_graph.preprocessing.properties.BondType
113-
- chebai_graph.preprocessing.properties.BondInRing
114-
- chebai_graph.preprocessing.properties.BondAromaticity
115-
- chebai_graph.preprocessing.properties.RDKit2DNormalized
11643
c3p_with_weights:
11744
type: c3p
11845
hugging_face:

chebifier/prediction_models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from .chebi_lookup import ChEBILookupPredictor
44
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
55
from .electra_predictor import ElectraPredictor
6-
from .gnn_predictor import ResGatedPredictor
6+
from .gnn_predictor import GNNPredictor
77

88
__all__ = [
99
"BasePredictor",
1010
"ChemlogPeptidesPredictor",
1111
"ElectraPredictor",
12-
"ResGatedPredictor",
12+
"GNNPredictor",
1313
"ChEBILookupPredictor",
1414
"ChemlogExtraPredictor",
1515
"C3PPredictor",

chebifier/prediction_models/electra_predictor.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
from typing import TYPE_CHECKING
2-
31
import numpy as np
42

53
from .nn_predictor import NNPredictor
64

7-
if TYPE_CHECKING:
8-
from chebai.models.electra import Electra
9-
105

116
def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
127
n_nodes = len(node_labels)
@@ -40,36 +35,31 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
4035

4136
class ElectraPredictor(NNPredictor):
4237
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
43-
from chebai.preprocessing.reader import ChemDataReader
44-
45-
super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
46-
print(f"Initialised Electra model {self.model_name} (device: {self.device})")
47-
48-
def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
49-
from chebai.models.electra import Electra
50-
51-
model = Electra.load_from_checkpoint(
52-
ckpt_path,
53-
map_location=self.device,
54-
criterion=None,
55-
strict=False,
56-
pretrained_checkpoint=None,
38+
super().__init__(model_name, ckpt_path, **kwargs)
39+
print(
40+
f"Initialised Electra model {self.model_name} (device: {self.predictor.device})"
5741
)
58-
model.eval()
59-
return model
6042

6143
def explain_smiles(self, smiles) -> dict:
6244
from chebai.preprocessing.reader import EMBEDDING_OFFSET
6345

64-
reader = self.reader_cls()
65-
token_dict = reader.to_data(dict(features=smiles, labels=[1, 2])) # dummy label
46+
# Add dummy labels because the collate function requires them.
47+
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
48+
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
49+
# Note: With New changes from https://github.com/ChEB-AI/python-chebai/pull/130, when labels are None, it also
50+
# causes problems with `missing_labels` handling. Hence using dummy labels.
51+
dummy_labels: list = list(range(1, self.predictor._model.out_dim + 1))
52+
53+
token_dict = self.predictor._dm.reader.to_data(
54+
dict(features=smiles, labels=dummy_labels)
55+
)
6656
tokens = np.array(token_dict["features"]).astype(int).tolist()
6757
result = self.calculate_results([token_dict])
6858

6959
token_labels = (
7060
["[CLR]"]
7161
+ [None for _ in range(EMBEDDING_OFFSET - 1)]
72-
+ list(reader.cache.keys())
62+
+ list(self.predictor._dm.reader.cache.keys())
7363
)
7464

7565
graphs = [
Lines changed: 4 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,14 @@
1-
from typing import TYPE_CHECKING, Optional
2-
3-
import torch
4-
51
from .nn_predictor import NNPredictor
62

7-
if TYPE_CHECKING:
8-
from chebai_graph.models.gat import GATGraphPred
9-
from chebai_graph.models.resgated import ResGatedGraphPred
10-
113

12-
class ResGatedPredictor(NNPredictor):
4+
class GNNPredictor(NNPredictor):
135
def __init__(
146
self,
157
model_name: str,
168
ckpt_path: str,
17-
molecular_properties,
18-
dataset_cls: Optional[str] = None,
199
**kwargs,
2010
):
21-
from chebai_graph.preprocessing.datasets.chebi import (
22-
ChEBI50GraphProperties,
23-
GraphPropertiesMixIn,
24-
)
25-
from chebai_graph.preprocessing.properties import MolecularProperty
26-
27-
# molecular_properties is a list of class paths
28-
if molecular_properties is not None:
29-
properties = [self.load_class(prop)() for prop in molecular_properties]
30-
properties = sorted(
31-
properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}"
32-
)
33-
else:
34-
properties = []
35-
for property in properties:
36-
property.encoder.eval = True
37-
self.molecular_properties = properties
38-
assert isinstance(self.molecular_properties, list) and all(
39-
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
40-
)
41-
# TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
42-
self.dataset_cls = (
43-
self.load_class(dataset_cls)
44-
if dataset_cls is not None
45-
else ChEBI50GraphProperties
46-
)
47-
self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls(
48-
properties=molecular_properties
49-
)
50-
51-
super().__init__(
52-
model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs
53-
)
54-
55-
print(f"Initialised GNN model {self.model_name} (device: {self.device})")
56-
57-
def load_class(self, class_path: str):
58-
module_path, class_name = class_path.rsplit(".", 1)
59-
module = __import__(module_path, fromlist=[class_name])
60-
return getattr(module, class_name)
61-
62-
def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
63-
import torch
64-
from chebai_graph.models.resgated import ResGatedGraphPred
65-
66-
model = ResGatedGraphPred.load_from_checkpoint(
67-
ckpt_path,
68-
map_location=torch.device(self.device),
69-
criterion=None,
70-
strict=False,
71-
)
72-
model.eval()
73-
return model
74-
75-
def read_smiles(self, smiles):
76-
from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType
77-
78-
d = self.dataset.READER().to_data(
79-
dict(features=smiles, labels=[1, 2])
80-
) # dummy label
81-
property_data = d
82-
# TODO merge props into base should not be a method of a dataset (or at least static)
83-
for property in self.dataset.properties:
84-
property.encoder.eval = True
85-
property_value = self.reader.read_property(smiles, property)
86-
if property_value is None or len(property_value) == 0:
87-
encoded_value = None
88-
else:
89-
encoded_value = torch.stack(
90-
[property.encoder.encode(v) for v in property_value]
91-
)
92-
if len(encoded_value.shape) == 3:
93-
encoded_value = encoded_value.squeeze(0)
94-
property_data[property.name] = encoded_value
95-
# Augmented graphs need an additional argument
96-
if isinstance(self.dataset, GraphPropAsPerNodeType):
97-
d["features"] = self.dataset._merge_props_into_base(
98-
property_data, max_len_node_properties=self.model.gnn.in_channels
99-
)
100-
else:
101-
d["features"] = self.dataset._merge_props_into_base(property_data)
102-
return d
103-
104-
105-
class GATPredictor(ResGatedPredictor):
106-
107-
def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
108-
import torch
109-
from chebai_graph.models.gat import GATGraphPred
110-
111-
model = GATGraphPred.load_from_checkpoint(
112-
ckpt_path,
113-
map_location=torch.device(self.device),
114-
criterion=None,
115-
strict=False,
11+
super().__init__(model_name, ckpt_path, **kwargs)
12+
print(
13+
f"Initialised GNN model {self.model_name} (device: {self.predictor.device})"
11614
)
117-
model.eval()
118-
return model

0 commit comments

Comments
 (0)