@@ -201,7 +201,7 @@ def _after_setup(self, **kwargs) -> None:
201201
202202 def _preprocess_smiles_for_pred (
203203 self , idx , raw_data : str | Chem .Mol , model_hparams : Optional [dict ] = None
204- ) -> dict :
204+ ) -> Optional [ dict ] :
205205 """Preprocess prediction data."""
206206 # Add dummy labels because the collate function requires them.
207207 # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
@@ -211,7 +211,7 @@ def _preprocess_smiles_for_pred(
211211 )
212212 # _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object
213213 if isinstance (result ["features" ], tuple ):
214- result ["features" ], raw_data = result ["features" ][ 0 ]
214+ result ["features" ], raw_data = result ["features" ]
215215 if result is None or result ["features" ] is None :
216216 return None
217217 for property in self .properties :
@@ -559,6 +559,8 @@ def _merge_props_into_base(
559559 geom_data = row ["features" ]
560560 if geom_data is None :
561561 return None
562+ if isinstance (geom_data , tuple ):
563+ geom_data = geom_data [0 ] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
562564 assert isinstance (geom_data , GeomData )
563565
564566 is_atom_node = geom_data .is_atom_node
0 commit comments