44
55import torch
66from chebai .preprocessing .reader import DataReader
7+ from chebai .preprocessing .datasets .chebi import sanitize_molecule
78from rdkit import Chem
89from torch_geometric .data import Data as GeomData
910
2223# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
2324# Order preservation is necessary to create the `is_atom_node` mask
2425
25-
2626class _AugmentorReader (DataReader , ABC ):
2727 """
2828 Abstract base class for augmentor readers that extend ChemDataReader.
@@ -59,12 +59,12 @@ def name(cls) -> str:
5959 """
6060 return f"{ cls .__name__ } " .lower ()
6161
62- def _read_data (self , smiles : str ) -> GeomData | None :
62+ def _read_data (self , raw_data : str | Chem . Mol ) -> GeomData | None :
6363 """
6464 Reads and augments molecular data from a SMILES string.
6565
6666 Args:
67- smiles (str): SMILES representation of the molecule.
67+ raw_data (str | Chem.Mol ): SMILES string or RDKit molecule object representing the molecule.
6868
6969 Returns:
7070 GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges,
@@ -73,20 +73,25 @@ def _read_data(self, smiles: str) -> GeomData | None:
7373 Raises:
7474 RuntimeError: If an unexpected error occurs during graph augmentation.
7575 """
76- mol = self ._smiles_to_mol (smiles )
76+ if isinstance (raw_data , str ):
77+ mol = self ._smiles_to_mol (raw_data )
78+ smiles = raw_data
79+ else :
80+ mol = raw_data
81+ smiles = Chem .MolToSmiles (mol )
7782 if mol is None :
7883 return None
7984
8085 try :
8186 returned_result = self ._create_augmented_graph (mol )
8287 except Exception as e :
83- raise RuntimeError (
84- f"Error has occurred for following SMILES: { smiles } \n \t { e } "
85- ) from e
88+ print ( f"Failed to construct augmented graph for smiles { smiles } , Error: { e } " )
89+ self . f_cnt_for_aug_graph += 1
90+ return None
8691
8792 # If the returned result is None, it indicates that the graph augmentation failed
8893 if returned_result is None :
89- print (f"Failed to construct augmented graph for smiles { smiles } " )
94+ print (f"Failed to construct augmented graph for smiles { smiles } (returned None) " )
9095 self .f_cnt_for_aug_graph += 1
9196 return None
9297
@@ -136,13 +141,13 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None:
136141 Returns:
137142 Chem.Mol | None: RDKit molecule object if successful, else None.
138143 """
139- mol = Chem .MolFromSmiles (smiles )
144+ mol = Chem .MolFromSmiles (smiles , sanitize = False )
140145 if mol is None :
141146 print (f"RDKit failed to parse { smiles } (returned None)" )
142147 self .f_cnt_for_smiles += 1
143148 else :
144149 try :
145- Chem . SanitizeMol (mol )
150+ mol = sanitize_molecule (mol )
146151 except Exception as e :
147152 print (f"RDKit failed at sanitizing { smiles } , Error { e } " )
148153 self .f_cnt_for_smiles += 1
@@ -662,17 +667,17 @@ def add_fg_internal_edge(source_fg: int, target_fg: int) -> None:
662667class _AddGraphNode (_AugmentorReader ):
663668 """Adds a graph-level node and connects it to selected/given nodes."""
664669
665- def _read_data (self , smiles : str ) -> GeomData | None :
670+ def _read_data (self , raw_data : str | Chem . Mol ) -> GeomData | None :
666671 """
667672 Reads data and adds a graph-level node annotation.
668673
669674 Args:
670- smiles (str): SMILES string.
675+ raw_data (str | Chem.Mol ): SMILES string or RDKit molecule object representing the molecule .
671676
672677 Returns:
673678 Data | None: Geometric data object with is_graph_node annotation.
674679 """
675- geom_data = super ()._read_data (smiles )
680+ geom_data = super ()._read_data (raw_data )
676681 if geom_data is None :
677682 return None
678683 NUM_NODES = geom_data .x .shape [0 ]
@@ -941,3 +946,4 @@ def _augment_graph_structure(
941946 return self ._add_graph_node_and_edges_to_nodes (
942947 augmented_struct , atom_ids | fg_to_atoms_map .keys ()
943948 )
949+
0 commit comments