@@ -59,16 +59,16 @@ def name(cls) -> str:
5959 """
6060 return f"{ cls .__name__ } " .lower ()
6161
62- def _read_data (self , raw_data : str | Chem .Mol ) -> GeomData | None :
62+ def _read_data (self , raw_data : str | Chem .Mol ) -> tuple [ GeomData , dict ] | None :
6363 """
6464 Reads and augments molecular data from a SMILES string.
6565
6666 Args:
6767 raw_data (str | Chem.Mol): SMILES string or RDKit molecule object representing the molecule.
6868
6969 Returns:
70- GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges,
71- or None if parsing or augmentation fails.
70+ tuple[ GeomData, dict] | None: A tuple containing a PyTorch Geometric Data object with augmented nodes and edges,
71+ and a dictionary of augmented molecule data, or None if parsing or augmentation fails.
7272
7373 Raises:
7474 RuntimeError: If an unexpected error occurs during graph augmentation.
@@ -124,12 +124,12 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
124124 NUM_ATOM_NODES = augmented_molecule ["nodes" ]["atom_nodes" ].GetNumAtoms ()
125125 is_atom_mask [:NUM_ATOM_NODES ] = True
126126
127- return GeomData (
127+ return ( GeomData (
128128 x = x ,
129129 edge_index = edge_index ,
130130 edge_attr = edge_attr ,
131131 is_atom_node = is_atom_mask ,
132- )
132+ ), augmented_molecule )
133133
134134 def _smiles_to_mol (self , smiles : str ) -> Chem .Mol | None :
135135 """
@@ -285,32 +285,35 @@ def on_finish(self) -> None:
285285 )
286286 self .mol_object_buffer = {}
287287
def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProperty) -> list | None:
    """
    Reads a specific property from a molecule.

    The molecule may be supplied as a SMILES string, an RDKit molecule object,
    or a dictionary representation of a molecule. A dictionary is handed to the
    property as-is; the other two forms are first turned into an augmented
    molecule via ``_create_augmented_graph``.

    Args:
        raw_data (str | Chem.Mol | dict): SMILES string, RDKit molecule object, or dictionary representation of a molecule.
        property (MolecularProperty): Molecular property object for which the value needs to be extracted.

    Returns:
        list | None: Property values if molecule parsing and augmentation are successful, else None.
    """
    # A dictionary is passed to the property unchanged; no parsing or
    # augmentation step is run for it.
    if isinstance(raw_data, dict):
        return property.get_property_value(raw_data)

    if isinstance(raw_data, Chem.Mol):
        mol = raw_data
    else:
        smiles = raw_data
        # Reuse the entry buffered under this SMILES string, if there is one,
        # instead of parsing and augmenting again.
        if smiles in self.mol_object_buffer:
            return property.get_property_value(self.mol_object_buffer[smiles])
        mol = self._smiles_to_mol(smiles)
        if mol is None:
            return None

    returned_result = self._create_augmented_graph(mol)
    if returned_result is None:
        return None

    # The first element of the pair is not needed here; only the augmented
    # molecule is passed on to the property.
    _, augmented_mol = returned_result
    return property.get_property_value(augmented_mol)
315318
316319
@@ -680,14 +683,14 @@ def _read_data(self, raw_data: str | Chem.Mol) -> GeomData | None:
680683 Returns:
681684 Data | None: Geometric data object with is_graph_node annotation.
682685 """
683- geom_data = super ()._read_data (raw_data )
686+ geom_data , augmented_mol = super ()._read_data (raw_data )
684687 if geom_data is None :
685688 return None
686689 NUM_NODES = geom_data .x .shape [0 ]
687690 is_graph_node = torch .zeros (NUM_NODES , dtype = torch .bool )
688691 is_graph_node [- 1 ] = True
689692 geom_data .is_graph_node = is_graph_node
690- return geom_data
693+ return ( geom_data , augmented_mol )
691694
692695 def _add_graph_node_and_edges_to_nodes (
693696 self ,
0 commit comments