Skip to content

Commit 5576839

Browse files
committed
add option for passing mol object, gentle error handling
1 parent 63216b4 commit 5576839

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
from chebai.preprocessing.reader import DataReader
7+
from chebai.preprocessing.datasets.chebi import sanitize_molecule
78
from rdkit import Chem
89
from torch_geometric.data import Data as GeomData
910

@@ -22,7 +23,6 @@
2223
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
2324
# Order preservation is necessary to to create `is_atom_node` mask
2425

25-
2626
class _AugmentorReader(DataReader, ABC):
2727
"""
2828
Abstract base class for augmentor readers that extend ChemDataReader.
@@ -59,12 +59,12 @@ def name(cls) -> str:
5959
"""
6060
return f"{cls.__name__}".lower()
6161

62-
def _read_data(self, smiles: str) -> GeomData | None:
62+
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
6363
"""
6464
Reads and augments molecular data from a SMILES string.
6565
6666
Args:
67-
smiles (str): SMILES representation of the molecule.
67+
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule.
6868
6969
Returns:
7070
GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges,
@@ -73,20 +73,25 @@ def _read_data(self, smiles: str) -> GeomData | None:
7373
Raises:
7474
RuntimeError: If an unexpected error occurs during graph augmentation.
7575
"""
76-
mol = self._smiles_to_mol(smiles)
76+
if isinstance(raw_data, str):
77+
mol = self._smiles_to_mol(raw_data)
78+
smiles = raw_data
79+
else:
80+
mol = raw_data
81+
smiles = Chem.MolToSmiles(mol)
7782
if mol is None:
7883
return None
7984

8085
try:
8186
returned_result = self._create_augmented_graph(mol)
8287
except Exception as e:
83-
raise RuntimeError(
84-
f"Error has occurred for following SMILES: {smiles}\n\t {e}"
85-
) from e
88+
print(f"Failed to construct augmented graph for smiles {smiles}, Error: {e}")
89+
self.f_cnt_for_aug_graph += 1
90+
return None
8691

8792
# If the returned result is None, it indicates that the graph augmentation failed
8893
if returned_result is None:
89-
print(f"Failed to construct augmented graph for smiles {smiles}")
94+
print(f"Failed to construct augmented graph for smiles {smiles} (returned None)")
9095
self.f_cnt_for_aug_graph += 1
9196
return None
9297

@@ -136,13 +141,13 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None:
136141
Returns:
137142
Chem.Mol | None: RDKit molecule object if successful, else None.
138143
"""
139-
mol = Chem.MolFromSmiles(smiles)
144+
mol = Chem.MolFromSmiles(smiles, sanitize=False)
140145
if mol is None:
141146
print(f"RDKit failed to parse {smiles} (returned None)")
142147
self.f_cnt_for_smiles += 1
143148
else:
144149
try:
145-
Chem.SanitizeMol(mol)
150+
mol = sanitize_molecule(mol)
146151
except Exception as e:
147152
print(f"RDKit failed at sanitizing {smiles}, Error {e}")
148153
self.f_cnt_for_smiles += 1
@@ -662,17 +667,17 @@ def add_fg_internal_edge(source_fg: int, target_fg: int) -> None:
662667
class _AddGraphNode(_AugmentorReader):
663668
"""Adds a graph-level node and connects it to selected/given nodes."""
664669

665-
def _read_data(self, smiles: str) -> GeomData | None:
670+
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
666671
"""
667672
Reads data and adds a graph-level node annotation.
668673
669674
Args:
670-
smiles (str): SMILES string.
675+
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule.
671676
672677
Returns:
673678
Data | None: Geometric data object with is_graph_node annotation.
674679
"""
675-
geom_data = super()._read_data(smiles)
680+
geom_data = super()._read_data(raw_data)
676681
if geom_data is None:
677682
return None
678683
NUM_NODES = geom_data.x.shape[0]
@@ -941,3 +946,4 @@ def _augment_graph_structure(
941946
return self._add_graph_node_and_edges_to_nodes(
942947
augmented_struct, atom_ids | fg_to_atoms_map.keys()
943948
)
949+

0 commit comments

Comments
 (0)