Skip to content

Commit 8e9d225

Browse files
committed
Adapt code for new logic to handle None returns
1 parent a11ba50 commit 8e9d225

File tree

1 file changed

+54
-67
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+54
-67
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 54 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -184,61 +184,53 @@ def _after_setup(self, **kwargs) -> None:
184184
self._setup_properties()
185185
super()._after_setup(**kwargs)
186186

187-
def _process_input_for_prediction(
188-
self,
189-
smiles_list: list[str],
190-
model_hparams: Optional[dict] = None,
191-
) -> list:
192-
data_df = self._process_smiles_and_props(smiles_list)
193-
data_df["features"] = data_df.apply(
194-
lambda row: self._merge_props_into_base(row), axis=1
187+
def _preprocess_smiles_for_pred(
188+
self, idx, smiles: str, model_hparams: Optional[dict] = None
189+
) -> dict:
190+
"""Preprocess prediction data."""
191+
# Add dummy labels because the collate function requires them.
192+
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
193+
# which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
194+
result = self.reader.to_data(
195+
{"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
196+
)
197+
if result is None or result["features"] is None:
198+
return None
199+
for property in self.properties:
200+
property.encoder.eval = True
201+
property_value = self.reader.read_property(smiles, property)
202+
if property_value is None or len(property_value) == 0:
203+
encoded_value = None
204+
else:
205+
encoded_value = torch.stack(
206+
[property.encoder.encode(v) for v in property_value]
207+
)
208+
if len(encoded_value.shape) == 3:
209+
encoded_value = encoded_value.squeeze(0)
210+
result[property.name] = encoded_value
211+
212+
result["features"] = self._prediction_merge_props_into_base_wrapper(
213+
result, model_hparams
195214
)
196215

197216
# apply transformation, e.g. masking for pretraining task
198217
if self.transform is not None:
199-
data_df["features"] = data_df["features"].apply(self.transform)
218+
result["features"] = self.transform(result["features"])
200219

201-
return data_df.to_dict("records")
220+
return result
202221

203-
def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
222+
def _prediction_merge_props_into_base_wrapper(
223+
self, row: pd.Series | dict, model_hparams: Optional[dict] = None
224+
) -> GeomData:
204225
"""
205-
Process SMILES strings and compute molecular properties.
226+
Wrapper to merge properties into base features for prediction.
227+
228+
Args:
229+
row: A dictionary or pd.Series containing 'features' and encoded properties.
230+
Returns:
231+
A GeomData object with merged features.
206232
"""
207-
data = [
208-
self.reader.to_data(
209-
{"ident": f"smiles_{idx}", "features": smiles, "labels": None}
210-
)
211-
for idx, smiles in enumerate(smiles_list)
212-
]
213-
# element of data is a dict with 'id' and 'features' (GeomData)
214-
# GeomData has only edge_index filled but node and edges features are empty.
215-
216-
assert len(data) == len(smiles_list), "Data length mismatch."
217-
data_df = pd.DataFrame(data)
218-
219-
props: list[dict] = []
220-
for data_row in data_df.itertuples(index=True):
221-
row_prop_dict: dict = {}
222-
for property in self.properties:
223-
property.encoder.eval = True
224-
property_value = self.reader.read_property(
225-
smiles_list[data_row.Index], property
226-
)
227-
if property_value is None or len(property_value) == 0:
228-
encoded_value = None
229-
else:
230-
encoded_value = torch.stack(
231-
[property.encoder.encode(v) for v in property_value]
232-
)
233-
if len(encoded_value.shape) == 3:
234-
encoded_value = encoded_value.squeeze(0)
235-
row_prop_dict[property.name] = encoded_value
236-
row_prop_dict["ident"] = data_row.ident
237-
props.append(row_prop_dict)
238-
239-
property_df = pd.DataFrame(props)
240-
data_df = data_df.merge(property_df, on="ident", how="left")
241-
return data_df
233+
return self._merge_props_into_base(row)
242234

243235

244236
class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
@@ -276,7 +268,7 @@ def __init__(
276268
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
277269
)
278270

279-
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
271+
def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
280272
"""
281273
Merge encoded molecular properties into the GeomData object.
282274
@@ -544,6 +536,8 @@ def _merge_props_into_base(
544536
A GeomData object with merged features.
545537
"""
546538
geom_data = row["features"]
539+
if geom_data is None:
540+
return None
547541
assert isinstance(geom_data, GeomData)
548542

549543
is_atom_node = geom_data.is_atom_node
@@ -627,11 +621,17 @@ def _merge_props_into_base(
627621
is_graph_node=is_graph_node,
628622
)
629623

630-
def _process_input_for_prediction(
631-
self,
632-
smiles_list: list[str],
633-
model_hparams: Optional[dict] = None,
634-
) -> list:
624+
def _prediction_merge_props_into_base_wrapper(
625+
self, row: pd.Series | dict, model_hparams: Optional[dict] = None
626+
) -> GeomData:
627+
"""
628+
Wrapper to merge properties into base features for prediction.
629+
630+
Args:
631+
row: A dictionary or pd.Series containing 'features' and encoded properties.
632+
Returns:
633+
A GeomData object with merged features.
634+
"""
635635
if (
636636
model_hparams is None
637637
or "in_channels" not in model_hparams["config"]
@@ -641,21 +641,8 @@ def _process_input_for_prediction(
641641
f"model_hparams must be provided for data class: {self.__class__.__name__}"
642642
f" which should contain 'in_channels' key with valid value in 'config' dictionary."
643643
)
644-
645644
max_len_node_properties = int(model_hparams["config"]["in_channels"])
646-
# Determine max_len_node_properties based on in_channels
647-
648-
data_df = self._process_smiles_and_props(smiles_list)
649-
data_df["features"] = data_df.apply(
650-
lambda row: self._merge_props_into_base(row, max_len_node_properties),
651-
axis=1,
652-
)
653-
654-
# apply transformation, e.g. masking for pretraining task
655-
if self.transform is not None:
656-
data_df["features"] = data_df["features"].apply(self.transform)
657-
658-
return data_df.to_dict("records")
645+
return self._merge_props_into_base(row, max_len_node_properties)
659646

660647

661648
class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):

0 commit comments

Comments (0)