Skip to content

Commit be7278d

Browse files
committed
avoid redundant property calculation
1 parent 66c9659 commit be7278d

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Optional
66

77
import pandas as pd
8+
from chebai_graph.preprocessing.reader.augmented_reader import _AugmentorReader
89
import torch
910
import tqdm
1011
from chebai.preprocessing.datasets.chebi import (
@@ -126,15 +127,22 @@ def enc_if_not_none(encode, value):
126127
if value is not None and len(value) > 0
127128
else None
128129
)
130+
131+
# Augment each molecule graph once up front when the reader is an _AugmentorReader (reading a property would otherwise trigger the same augmentation again, so doing it here avoids redundant work).
132+
if isinstance(self.reader, _AugmentorReader):
133+
returned_results = [self._create_augmented_graph(mol) for mol in features]
134+
mols = [augmented_mol[1] for augmented_mol in returned_results if augmented_mol is not None]
135+
else:
136+
mols = features
129137

130138
for property in self.properties:
131139
if not os.path.isfile(self.get_property_path(property)):
132140
rank_zero_info(f"Processing property {property.name}")
133141
# read all property values first, then encode
134142
rank_zero_info(f"\tReading property values of {property.name}...")
135143
property_values = [
136-
self.reader.read_property(feat, property)
137-
for feat in tqdm.tqdm(features)
144+
self.reader.read_property(mol, property)
145+
for mol in tqdm.tqdm(mols)
138146
]
139147
rank_zero_info(f"\tEncoding property values of {property.name}...")
140148
property.encoder.on_start(property_values=property_values)

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,12 @@ def read_property(self, raw_data: str | Chem.Mol | dict, property: MolecularProp
308308
mol = self._smiles_to_mol(smiles)
309309
if mol is None:
310310
return None
311-
312-
returned_result = self._create_augmented_graph(mol)
311+
try:
312+
returned_result = self._create_augmented_graph(mol)
313+
except Exception as e:
314+
print(f"Failed to construct augmented graph, Error: {e}")
315+
self.f_cnt_for_aug_graph += 1
316+
return None
313317
if returned_result is None:
314318
return None
315319

0 commit comments

Comments
 (0)