Skip to content

Commit 66c9659

Browse files
committed
make read_properties more flexible, don't recalculate properties
1 parent ad4432f commit 66c9659

3 files changed

Lines changed: 36 additions & 29 deletions

File tree

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from lightning_utilities.core.rank_zero import rank_zero_info
1717
from torch_geometric.data.data import Data as GeomData
18+
from rdkit import Chem
1819

1920
from chebai_graph.preprocessing.properties import (
2021
AllNodeTypeProperty,
@@ -185,20 +186,23 @@ def _after_setup(self, **kwargs) -> None:
185186
super()._after_setup(**kwargs)
186187

187188
def _preprocess_smiles_for_pred(
188-
self, idx, smiles: str, model_hparams: Optional[dict] = None
189+
self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None
189190
) -> dict:
190191
"""Preprocess prediction data."""
191192
# Add dummy labels because the collate function requires them.
192193
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
193194
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
194195
result = self.reader.to_data(
195-
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
196+
{"id": f"smiles_{idx}", "features": raw_data, "labels": [1, 2]}
196197
)
198+
# _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object
199+
if isinstance(result["features"], tuple):
200+
result["features"], raw_data = result["features"][0]
197201
if result is None or result["features"] is None:
198202
return None
199203
for property in self.properties:
200204
property.encoder.eval = True
201-
property_value = self.reader.read_property(smiles, property)
205+
property_value = self.reader.read_property(raw_data, property)
202206
if property_value is None or len(property_value) == 0:
203207
encoded_value = None
204208
else:

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ def name(cls) -> str:
5959
"""
6060
return f"{cls.__name__}".lower()
6161

62-
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
62+
def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, dict] | None:
6363
"""
6464
Reads and augments molecular data from a SMILES string.
6565
6666
Args:
6767
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule.
6868
6969
Returns:
70-
GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges,
71-
or None if parsing or augmentation fails.
70+
tuple[GeomData, dict] | None: A tuple containing a PyTorch Geometric Data object with augmented nodes and edges,
71+
and a dictionary of augmented molecule data, or None if parsing or augmentation fails.
7272
7373
Raises:
7474
RuntimeError: If an unexpected error occurs during graph augmentation.
@@ -124,12 +124,12 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
124124
NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms()
125125
is_atom_mask[:NUM_ATOM_NODES] = True
126126

127-
return GeomData(
127+
return (GeomData(
128128
x=x,
129129
edge_index=edge_index,
130130
edge_attr=edge_attr,
131131
is_atom_node=is_atom_mask,
132-
)
132+
), augmented_molecule)
133133

134134
def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None:
135135
"""
@@ -285,32 +285,35 @@ def on_finish(self) -> None:
285285
)
286286
self.mol_object_buffer = {}
287287

288-
def read_property(self, data: str | Chem.Mol, property: MolecularProperty) -> list | None:
288+
def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProperty) -> list | None:
289289
"""
290290
Reads a specific property from a molecule represented by a SMILES string.
291291
292292
Args:
293-
data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule.
293+
raw_data (str | Chem.Mol | dict): SMILES string, RDKit molecule object, or dictionary representation of a molecule.
294294
property (MolecularProperty): Molecular property object for which the value needs to be extracted.
295295
296296
Returns:
297297
list | None: Property values if molecule parsing is successful, else None.
298298
"""
299-
if isinstance(data, Chem.Mol):
300-
mol = data
299+
if isinstance(raw_data, dict):
300+
augmented_mol = raw_data
301301
else:
302-
smiles = data
303-
if smiles in self.mol_object_buffer:
304-
return property.get_property_value(self.mol_object_buffer[smiles])
305-
mol = self._smiles_to_mol(smiles)
306-
if mol is None:
307-
return None
302+
if isinstance(raw_data, Chem.Mol):
303+
mol = raw_data
304+
else:
305+
smiles = raw_data
306+
if smiles in self.mol_object_buffer:
307+
return property.get_property_value(self.mol_object_buffer[smiles])
308+
mol = self._smiles_to_mol(smiles)
309+
if mol is None:
310+
return None
308311

309-
returned_result = self._create_augmented_graph(mol)
310-
if returned_result is None:
311-
return None
312+
returned_result = self._create_augmented_graph(mol)
313+
if returned_result is None:
314+
return None
312315

313-
_, augmented_mol = returned_result
316+
_, augmented_mol = returned_result
314317
return property.get_property_value(augmented_mol)
315318

316319

@@ -680,14 +683,14 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
680683
Returns:
681684
Data | None: Geometric data object with is_graph_node annotation.
682685
"""
683-
geom_data = super()._read_data(raw_data)
686+
geom_data, augmented_mol = super()._read_data(raw_data)
684687
if geom_data is None:
685688
return None
686689
NUM_NODES = geom_data.x.shape[0]
687690
is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool)
688691
is_graph_node[-1] = True
689692
geom_data.is_graph_node = is_graph_node
690-
return geom_data
693+
return (geom_data, augmented_mol)
691694

692695
def _add_graph_node_and_edges_to_nodes(
693696
self,

chebai_graph/preprocessing/reader/reader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None:
6868
self.mol_object_buffer[smiles] = mol
6969
return mol
7070

71-
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
71+
def _read_data(self, raw_data: str | Chem.Mol) -> tuple[GeomData, Chem.Mol] | None:
7272
"""
7373
Convert raw SMILES string data into a PyTorch Geometric Data object.
7474
@@ -95,7 +95,7 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
9595
# edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features]
9696
edge_attr = torch.zeros((edge_index.size(1), 0))
9797

98-
return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr)
98+
return (GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr), mol)
9999

100100
def on_finish(self) -> None:
101101
"""
@@ -104,18 +104,18 @@ def on_finish(self) -> None:
104104
print(f"Failed to read {self.failed_counter} SMILES in total")
105105
self.mol_object_buffer = {}
106106

107-
def read_property(self, smiles: str, property: MolecularProperty) -> list | None:
107+
def read_property(self, raw_data: str | Chem.Mol, property: MolecularProperty) -> list | None:
108108
"""
109109
Read a molecular property for a given SMILES string.
110110
111111
Args:
112-
smiles (str): SMILES string of the molecule.
112+
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object of the molecule.
113113
property (MolecularProperty): Property extractor to apply.
114114
115115
Returns:
116116
list | None: Property values or None if molecule parsing failed.
117117
"""
118-
mol = self._smiles_to_mol(smiles)
118+
mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data
119119
if mol is None:
120120
return None
121121
return property.get_property_value(mol)

0 commit comments

Comments
 (0)