Skip to content

Commit 95adba2

Browse files
committed
update reader for mol objects
1 parent 5576839 commit 95adba2

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

chebai_graph/preprocessing/reader/reader.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
import chebai.preprocessing.reader as dr
4+
from chebai.preprocessing.datasets.chebi import sanitize_molecule
45
import networkx as nx
56
import pysmiles as ps
67
import rdkit.Chem as Chem
@@ -54,30 +55,33 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None:
5455
if smiles in self.mol_object_buffer:
5556
return self.mol_object_buffer[smiles]
5657

57-
mol = Chem.MolFromSmiles(smiles)
58+
mol = Chem.MolFromSmiles(smiles, sanitize=False)
5859
if mol is None:
5960
print(f"RDKit failed to at parsing {smiles} (returned None)")
6061
self.failed_counter += 1
6162
else:
6263
try:
63-
Chem.SanitizeMol(mol)
64+
sanitize_molecule(mol)
6465
except Exception as e:
6566
print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}")
6667
self.failed_counter += 1
6768
self.mol_object_buffer[smiles] = mol
6869
return mol
6970

70-
def _read_data(self, raw_data: str) -> GeomData | None:
71+
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
7172
"""
7273
Convert a raw SMILES string or an RDKit molecule object into a PyTorch Geometric Data object.
7374
7475
Args:
75-
raw_data (str): SMILES string.
76+
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object.
7677
7778
Returns:
7879
GeomData | None: Graph data object or None if molecule parsing failed.
7980
"""
80-
mol = self._smiles_to_mol(raw_data)
81+
if isinstance(raw_data, Chem.Mol):
82+
mol = raw_data
83+
else:
84+
mol = self._smiles_to_mol(raw_data)
8185
if mol is None:
8286
return None
8387

@@ -144,19 +148,19 @@ def name(cls) -> str:
144148
"""
145149
return "graph"
146150

147-
def _read_data(self, raw_data: str) -> GeomData | None:
151+
def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
148152
"""
149153
Convert a SMILES string or an RDKit molecule object into a PyTorch Geometric Data object with atom tokens and bond order attributes.
150154
151155
Args:
152-
raw_data (str): SMILES string.
156+
raw_data (str | Chem.Mol): SMILES string or RDKit molecule object.
153157
154158
Returns:
155159
GeomData | None: Graph data object or None if parsing failed.
156160
"""
157161
# raw_data is either a SMILES string (parsed below) or an already-built molecule object (used as-is)
158162
try:
159-
mol = ps.read_smiles(raw_data)
163+
mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data
160164
except ValueError:
161165
return None
162166
assert isinstance(mol, nx.Graph)
@@ -189,6 +193,27 @@ def _read_data(self, raw_data: str) -> GeomData | None:
189193
nx.set_edge_attributes(mol, de, "edge_attr")
190194
data = from_networkx(mol)
191195
return data
196+
197+
def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None:
198+
"""
199+
Load SMILES string into an RDKit molecule object.
200+
201+
Args:
202+
smiles (str): The SMILES string to parse.
203+
204+
Returns:
205+
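Chem.rdchem.Mol | None: Parsed molecule object, or None if RDKit could not parse the SMILES. A molecule whose sanitization fails is still returned (the error is only printed).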
206+
"""
207+
208+
mol = Chem.MolFromSmiles(smiles, sanitize=False)
209+
if mol is None:
210+
print(f"RDKit failed to at parsing {smiles} (returned None)")
211+
else:
212+
try:
213+
sanitize_molecule(mol)
214+
except Exception as e:
215+
print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}")
216+
return mol
192217

193218
def collate(self, list_of_tuples: list) -> any:
194219
"""

0 commit comments

Comments
 (0)