|
1 | 1 | import os |
2 | 2 |
|
3 | 3 | import chebai.preprocessing.reader as dr |
| 4 | +from chebai.preprocessing.datasets.chebi import sanitize_molecule |
4 | 5 | import networkx as nx |
5 | 6 | import pysmiles as ps |
6 | 7 | import rdkit.Chem as Chem |
@@ -54,30 +55,33 @@ def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: |
54 | 55 | if smiles in self.mol_object_buffer: |
55 | 56 | return self.mol_object_buffer[smiles] |
56 | 57 |
|
57 | | - mol = Chem.MolFromSmiles(smiles) |
| 58 | + mol = Chem.MolFromSmiles(smiles, sanitize=False) |
58 | 59 | if mol is None: |
59 | 60 | print(f"RDKit failed to at parsing {smiles} (returned None)") |
60 | 61 | self.failed_counter += 1 |
61 | 62 | else: |
62 | 63 | try: |
63 | | - Chem.SanitizeMol(mol) |
| 64 | + sanitize_molecule(mol) |
64 | 65 | except Exception as e: |
65 | 66 | print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") |
66 | 67 | self.failed_counter += 1 |
67 | 68 | self.mol_object_buffer[smiles] = mol |
68 | 69 | return mol |
69 | 70 |
|
70 | | - def _read_data(self, raw_data: str) -> GeomData | None: |
| 71 | + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: |
71 | 72 | """ |
72 | 73 | Convert raw SMILES string data into a PyTorch Geometric Data object. |
73 | 74 |
|
74 | 75 | Args: |
75 | | - raw_data (str): SMILES string. |
| 76 | + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. |
76 | 77 |
|
77 | 78 | Returns: |
78 | 79 | GeomData | None: Graph data object or None if molecule parsing failed. |
79 | 80 | """ |
80 | | - mol = self._smiles_to_mol(raw_data) |
| 81 | + if isinstance(raw_data, Chem.Mol): |
| 82 | + mol = raw_data |
| 83 | + else: |
| 84 | + mol = self._smiles_to_mol(raw_data) |
81 | 85 | if mol is None: |
82 | 86 | return None |
83 | 87 |
|
@@ -144,19 +148,19 @@ def name(cls) -> str: |
144 | 148 | """ |
145 | 149 | return "graph" |
146 | 150 |
|
147 | | - def _read_data(self, raw_data: str) -> GeomData | None: |
| 151 | + def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None: |
148 | 152 | """ |
149 | 153 | Convert a SMILES string into a PyTorch Geometric Data object with atom tokens and bond order attributes. |
150 | 154 |
|
151 | 155 | Args: |
152 | | - raw_data (str): SMILES string. |
| 156 | + raw_data (str | Chem.Mol): SMILES string or RDKit molecule object. |
153 | 157 |
|
154 | 158 | Returns: |
155 | 159 | GeomData | None: Graph data object or None if parsing failed. |
156 | 160 | """ |
157 | 161 | # raw_data is a SMILES string |
158 | 162 | try: |
159 | | - mol = ps.read_smiles(raw_data) |
| 163 | + mol = self._smiles_to_mol(raw_data) if isinstance(raw_data, str) else raw_data |
160 | 164 | except ValueError: |
161 | 165 | return None |
162 | 166 | assert isinstance(mol, nx.Graph) |
@@ -189,6 +193,27 @@ def _read_data(self, raw_data: str) -> GeomData | None: |
189 | 193 | nx.set_edge_attributes(mol, de, "edge_attr") |
190 | 194 | data = from_networkx(mol) |
191 | 195 | return data |
| 196 | + |
| 197 | + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: |
| 198 | + """ |
| 199 | + Load SMILES string into an RDKit molecule object. |
| 200 | +
|
| 201 | + Args: |
| 202 | + smiles (str): The SMILES string to parse. |
| 203 | +
|
| 204 | + Returns: |
| 205 | + Chem.rdchem.Mol | None: Parsed molecule object or None if parsing failed. |
| 206 | + """ |
| 207 | + |
| 208 | + mol = Chem.MolFromSmiles(smiles, sanitize=False) |
| 209 | + if mol is None: |
| 210 | + print(f"RDKit failed to at parsing {smiles} (returned None)") |
| 211 | + else: |
| 212 | + try: |
| 213 | + sanitize_molecule(mol) |
| 214 | + except Exception as e: |
| 215 | + print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") |
| 216 | + return mol |
192 | 217 |
|
193 | 218 | def collate(self, list_of_tuples: list) -> any: |
194 | 219 | """ |
|
0 commit comments