@@ -184,61 +184,53 @@ def _after_setup(self, **kwargs) -> None:
184184 self ._setup_properties ()
185185 super ()._after_setup (** kwargs )
186186
187- def _process_input_for_prediction (
188- self ,
189- smiles_list : list [str ],
190- model_hparams : Optional [dict ] = None ,
191- ) -> list :
192- data_df = self ._process_smiles_and_props (smiles_list )
193- data_df ["features" ] = data_df .apply (
194- lambda row : self ._merge_props_into_base (row ), axis = 1
187+ def _preprocess_smiles_for_pred (
188+ self , idx , smiles : str , model_hparams : Optional [dict ] = None
189+ ) -> dict :
190+ """Preprocess prediction data."""
191+ # Add dummy labels because the collate function requires them.
192+ # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
193+ # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
194+ result = self .reader .to_data (
195+ {"id" : f"smiles_{ idx } " , "features" : smiles , "labels" : [1 , 2 ]}
196+ )
197+ if result is None or result ["features" ] is None :
198+ return None
199+ for property in self .properties :
200+ property .encoder .eval = True
201+ property_value = self .reader .read_property (smiles , property )
202+ if property_value is None or len (property_value ) == 0 :
203+ encoded_value = None
204+ else :
205+ encoded_value = torch .stack (
206+ [property .encoder .encode (v ) for v in property_value ]
207+ )
208+ if len (encoded_value .shape ) == 3 :
209+ encoded_value = encoded_value .squeeze (0 )
210+ result [property .name ] = encoded_value
211+
212+ result ["features" ] = self ._prediction_merge_props_into_base_wrapper (
213+ result , model_hparams
195214 )
196215
197216 # apply transformation, e.g. masking for pretraining task
198217 if self .transform is not None :
199- data_df ["features" ] = data_df ["features" ]. apply ( self . transform )
218+ result ["features" ] = self . transform ( result ["features" ])
200219
201- return data_df . to_dict ( "records" )
220+ return result
202221
203- def _process_smiles_and_props (self , smiles_list : list [str ]) -> pd .DataFrame :
222+ def _prediction_merge_props_into_base_wrapper (
223+ self , row : pd .Series | dict , model_hparams : Optional [dict ] = None
224+ ) -> GeomData :
204225 """
205- Process SMILES strings and compute molecular properties.
226+ Wrapper to merge properties into base features for prediction.
227+
228+ Args:
229+ row: A dictionary or pd.Series containing 'features' and encoded properties.
230+ Returns:
231+ A GeomData object with merged features.
206232 """
207- data = [
208- self .reader .to_data (
209- {"ident" : f"smiles_{ idx } " , "features" : smiles , "labels" : None }
210- )
211- for idx , smiles in enumerate (smiles_list )
212- ]
213- # element of data is a dict with 'id' and 'features' (GeomData)
214- # GeomData has only edge_index filled but node and edges features are empty.
215-
216- assert len (data ) == len (smiles_list ), "Data length mismatch."
217- data_df = pd .DataFrame (data )
218-
219- props : list [dict ] = []
220- for data_row in data_df .itertuples (index = True ):
221- row_prop_dict : dict = {}
222- for property in self .properties :
223- property .encoder .eval = True
224- property_value = self .reader .read_property (
225- smiles_list [data_row .Index ], property
226- )
227- if property_value is None or len (property_value ) == 0 :
228- encoded_value = None
229- else :
230- encoded_value = torch .stack (
231- [property .encoder .encode (v ) for v in property_value ]
232- )
233- if len (encoded_value .shape ) == 3 :
234- encoded_value = encoded_value .squeeze (0 )
235- row_prop_dict [property .name ] = encoded_value
236- row_prop_dict ["ident" ] = data_row .ident
237- props .append (row_prop_dict )
238-
239- property_df = pd .DataFrame (props )
240- data_df = data_df .merge (property_df , on = "ident" , how = "left" )
241- return data_df
233+ return self ._merge_props_into_base (row )
242234
243235
244236class GraphPropertiesMixIn (DataPropertiesSetter , ABC ):
@@ -276,7 +268,7 @@ def __init__(
276268 f"Data module uses these properties (ordered): { ', ' .join ([str (p ) for p in self .properties ])} "
277269 )
278270
279- def _merge_props_into_base (self , row : pd .Series ) -> GeomData :
271+ def _merge_props_into_base (self , row : pd .Series | dict ) -> GeomData :
280272 """
281273 Merge encoded molecular properties into the GeomData object.
282274
@@ -544,6 +536,8 @@ def _merge_props_into_base(
544536 A GeomData object with merged features.
545537 """
546538 geom_data = row ["features" ]
539+ if geom_data is None :
540+ return None
547541 assert isinstance (geom_data , GeomData )
548542
549543 is_atom_node = geom_data .is_atom_node
@@ -627,11 +621,17 @@ def _merge_props_into_base(
627621 is_graph_node = is_graph_node ,
628622 )
629623
630- def _process_input_for_prediction (
631- self ,
632- smiles_list : list [str ],
633- model_hparams : Optional [dict ] = None ,
634- ) -> list :
624+ def _prediction_merge_props_into_base_wrapper (
625+ self , row : pd .Series | dict , model_hparams : Optional [dict ] = None
626+ ) -> GeomData :
627+ """
628+ Wrapper to merge properties into base features for prediction.
629+
630+ Args:
631+ row: A dictionary or pd.Series containing 'features' and encoded properties.
632+ Returns:
633+ A GeomData object with merged features.
634+ """
635635 if (
636636 model_hparams is None
637637 or "in_channels" not in model_hparams ["config" ]
@@ -641,21 +641,8 @@ def _process_input_for_prediction(
641641 f"model_hparams must be provided for data class: { self .__class__ .__name__ } "
642642 f" which should contain 'in_channels' key with valid value in 'config' dictionary."
643643 )
644-
645644 max_len_node_properties = int (model_hparams ["config" ]["in_channels" ])
646- # Determine max_len_node_properties based on in_channels
647-
648- data_df = self ._process_smiles_and_props (smiles_list )
649- data_df ["features" ] = data_df .apply (
650- lambda row : self ._merge_props_into_base (row , max_len_node_properties ),
651- axis = 1 ,
652- )
653-
654- # apply transformation, e.g. masking for pretraining task
655- if self .transform is not None :
656- data_df ["features" ] = data_df ["features" ].apply (self .transform )
657-
658- return data_df .to_dict ("records" )
645+ return self ._merge_props_into_base (row , max_len_node_properties )
659646
660647
661648class ChEBI50_StaticGNI (DataPropertiesSetter , ChEBIOver50 ):
0 commit comments