Skip to content

Commit d2bbad5

Browse files
committed
fix assertion error
1 parent e78183e commit d2bbad5

File tree

1 file changed

+4
-2
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+4
-2
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _after_setup(self, **kwargs) -> None:
201201

202202
def _preprocess_smiles_for_pred(
203203
self, idx, raw_data: str | Chem.Mol, model_hparams: Optional[dict] = None
204-
) -> dict:
204+
) -> Optional[dict]:
205205
"""Preprocess prediction data."""
206206
# Add dummy labels because the collate function requires them.
207207
# Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
@@ -211,7 +211,7 @@ def _preprocess_smiles_for_pred(
211211
)
212212
# _read_data can return an updated version of the input data (e.g. augmented molecule dict) along with the GeomData object
213213
if isinstance(result["features"], tuple):
214-
result["features"], raw_data = result["features"][0]
214+
result["features"], raw_data = result["features"]
215215
if result is None or result["features"] is None:
216216
return None
217217
for property in self.properties:
@@ -559,6 +559,8 @@ def _merge_props_into_base(
559559
geom_data = row["features"]
560560
if geom_data is None:
561561
return None
562+
if isinstance(geom_data, tuple):
563+
geom_data = geom_data[0] # ignore additional returned data from _read_data (e.g. augmented molecule dict)
562564
assert isinstance(geom_data, GeomData)
563565

564566
is_atom_node = geom_data.is_atom_node

0 commit comments

Comments
 (0)