
Commit 63216b4

Merge pull request #31 from ChEB-AI/fix/predict_pipeline
Prediction functional for Graphs
2 parents 301b7c6 + 2406f95

File tree

5 files changed (+90 -15 lines)

README.md

Lines changed: 12 additions & 8 deletions
@@ -73,7 +73,7 @@ The dataset has a customizable list of properties for atoms, bonds and molecules
 The list can be found in the `configs/data/chebi50_graph_properties.yml` file.

 ```bash
-python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce_weighted.yml
 ```
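The commands in this commit swap `configs/loss/bce.yml` for `configs/loss/bce_weighted.yml`. The repository's exact config is not shown here, but a class-weighted binary cross-entropy typically just scales the positive term per label; a minimal NumPy sketch (the `pos_weight` values below are illustrative, not taken from the repo):

```python
import numpy as np

def weighted_bce(probs: np.ndarray, targets: np.ndarray, pos_weight: np.ndarray) -> float:
    """Binary cross-entropy with a per-label weight on the positive term."""
    eps = 1e-12  # guard against log(0)
    pos = pos_weight * targets * np.log(probs + eps)
    neg = (1.0 - targets) * np.log(1.0 - probs + eps)
    return float(-(pos + neg).mean())

probs = np.array([0.9, 0.2, 0.7])    # predicted probabilities
targets = np.array([1.0, 0.0, 1.0])  # multi-label targets
# up-weight rare positive labels, e.g. weight 3 on the last label
loss = weighted_bce(probs, targets, pos_weight=np.array([1.0, 1.0, 3.0]))
```

With `pos_weight` all ones this reduces to plain BCE; larger weights penalize missed positives on rare classes more heavily, which is the usual motivation for a weighted criterion on imbalanced ChEBI labels.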

 ## Augmented Graphs
@@ -94,7 +94,7 @@ Among all the connection schemes we evaluated, this configuration delivered the
 Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs.

 ```bash
-python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce_weighted.yml --trainer.logger.init_args.name=gatv2_amg_s0
 ```

 ### Model Hyperparameters
@@ -104,7 +104,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo
 To use a GAT-based model, choose **one** of the following configs:

 - **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml`
-  > Standard pooling sums the learned representations from all the nodes to produce a single representation which is used for classification.
+  > Standard pooling sums the learned representations from all the nodes to produce a single representation which is used for classification.
 - **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml`
   > With this pooling strategy, the learned representations are first separated into **two distinct sets**: those from atom nodes and those from all artificial nodes (both functional groups and the graph node). The representations within each set are aggregated separately (using summation) to yield two distinct single vectors. These two resulting vectors are then concatenated before being passed to the classification layer.
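The atom-augmented pooling described above can be sketched in plain NumPy: split the node embeddings by an atom/artificial mask, sum each set, and concatenate the two vectors (array shapes are illustrative, not the repo's actual tensors):

```python
import numpy as np

def atom_augmented_pool(node_repr: np.ndarray, is_atom_node: np.ndarray) -> np.ndarray:
    """Sum atom-node and artificial-node representations separately, then concatenate."""
    atom_vec = node_repr[is_atom_node].sum(axis=0)         # atoms only
    artificial_vec = node_repr[~is_atom_node].sum(axis=0)  # functional groups + graph node
    return np.concatenate([atom_vec, artificial_vec])      # shape: (2 * hidden_dim,)

node_repr = np.random.randn(6, 8)  # 6 nodes, hidden_dim = 8
is_atom_node = np.array([True, True, True, True, False, False])
pooled = atom_augmented_pool(node_repr, is_atom_node)  # shape (16,)
```

The concatenated vector is twice the hidden size, so the classification layer after this pooling must accept `2 * hidden_dim` inputs rather than `hidden_dim`.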
@@ -117,9 +117,13 @@ To use a GAT-based model, choose **one** of the following configs:
 - **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4)
 - **Attention heads**: `--model.config.heads=4` (default: 8)
   > **Note**: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified).
-- **Use GATv2**: `--model.config.v2=True` (default: False)
-  > **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).
-
+
+- **To use different GAT versions**:
+  - **Use GAT**: `--model.config.v2=False`
+
+  - **Use GATv2**: `--model.config.v2=True` (__default__)
+    > **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).
+
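The static-vs-dynamic distinction behind the `v2` flag: GAT applies the learned attention vector `a` before the nonlinearity, which makes the ranking of neighbours effectively independent of the query node, while GATv2 applies `a` after the nonlinearity, making attention query-dependent. A schematic single-head NumPy comparison of the two scoring rules (toy dimensions; this is not the library implementation):

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

rng = np.random.default_rng(0)
h_i, h_j = rng.normal(size=4), rng.normal(size=4)  # features of two nodes

# GAT (static attention): nonlinearity applied after the dot product with `a`
W = rng.normal(size=(4, 4))   # shared linear map applied to each node
a = rng.normal(size=8)        # attention vector over [Wh_i || Wh_j]
e_gat = leaky_relu(a @ np.concatenate([W @ h_i, W @ h_j]))

# GATv2 (dynamic attention): `a` applied after the nonlinearity
W_v2 = rng.normal(size=(8, 8))  # linear map on the concatenated pair
e_gatv2 = a @ leaky_relu(W_v2 @ np.concatenate([h_i, h_j]))
```

Both rules produce one scalar score per edge; softmax over a node's neighbours then yields the attention coefficients.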
 #### **ResGated Architecture**

 To use a ResGated GNN model, choose **one** of the following configs:
@@ -142,7 +146,7 @@ These can be used for both GAT and ResGated architectures:
 In this type of node initialization, the node features (and/or edge features) of the given molecular graph are initialized only once during dataset creation with the given initialization scheme.

 ```bash
-python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.pad_node_features=45 --data.pad_edge_features=4 --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce_weighted.yml --trainer.logger.init_args.name=gni_res_props+zeros_s0
 ```

 In the above command, for each node we use the 158 node features (corresponding to the node properties defined in `chebi50_graph_properties.yml`) which are retrieved from RDKit, plus 45 additional features (specified by `--data.pad_node_features=45`) drawn from a normal distribution (default).
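The static padding described above (158 RDKit-derived features plus 45 random ones, giving `in_channels=203`) amounts to a one-time augmentation at dataset creation. A NumPy sketch where the numbers mirror the command but the function name is ours:

```python
import numpy as np

def pad_node_features(features: np.ndarray, pad: int, rng=None) -> np.ndarray:
    """Append `pad` extra features per node, drawn once from a standard normal."""
    rng = rng or np.random.default_rng()
    extra = rng.normal(size=(features.shape[0], pad))
    return np.concatenate([features, extra], axis=1)

atoms = np.zeros((12, 158))           # 12 atoms, 158 RDKit property features
x = pad_node_features(atoms, pad=45)  # shape (12, 203), matching in_channels=203
```

Because the random columns are drawn once and stored with the dataset, every epoch sees the same values; that is what distinguishes this static scheme from the dynamic variant below.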
@@ -184,5 +188,5 @@ If all features should be initialized from the given distribution, remove the co
 Please find below the command for a typical dynamic node initialization:

 ```bash
-python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/resgated_dynamic_gni.yml --model.config.in_channels=203 --model.config.edge_dim=11 --model.config.complete_randomness=False --model.config.pad_node_features=45 --model.config.pad_edge_features=4 --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --data.init_args.persistent_workers=False --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce_weighted.yml --trainer.logger.init_args.name=gni_dres_props+rand_s0
 ```
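Dynamic node initialization differs from the static variant in *when* the random features are drawn: instead of once at dataset creation, they are re-sampled in the model, so each pass sees fresh values. A rough sketch of that re-sampling step, with `complete_randomness=False` keeping the RDKit-derived part fixed (function and argument names here are illustrative, not the repo's API):

```python
import numpy as np

def dynamic_pad(features, pad_node_features, complete_randomness, rng):
    """Re-draw random features on every call (e.g. once per forward pass)."""
    n = features.shape[0]
    if complete_randomness:
        # replace *all* features with fresh random values
        return rng.normal(size=(n, features.shape[1] + pad_node_features))
    # keep the RDKit-derived part, re-sample only the padding
    return np.concatenate([features, rng.normal(size=(n, pad_node_features))], axis=1)

rng = np.random.default_rng(0)
feats = np.ones((5, 158))
a = dynamic_pad(feats, 45, False, rng)
b = dynamic_pad(feats, 45, False, rng)  # the padded columns differ between calls
```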

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 75 additions & 2 deletions
@@ -77,7 +77,7 @@ def __init__(
             properties = self._sort_properties(properties)
         else:
             properties = []
-        self.properties = properties
+        self.properties: list[MolecularProperty] = properties
         assert isinstance(self.properties, list) and all(
             isinstance(p, MolecularProperty) for p in self.properties
         )
@@ -184,6 +184,54 @@ def _after_setup(self, **kwargs) -> None:
         self._setup_properties()
         super()._after_setup(**kwargs)

+    def _preprocess_smiles_for_pred(
+        self, idx, smiles: str, model_hparams: Optional[dict] = None
+    ) -> dict:
+        """Preprocess prediction data."""
+        # Add dummy labels because the collate function requires them.
+        # Note: If labels are set to `None`, the collator will insert a `non_null_labels` entry into `loss_kwargs`,
+        # which later causes `_get_prediction_and_labels` method in the prediction pipeline to treat the data as empty.
+        result = self.reader.to_data(
+            {"id": f"smiles_{idx}", "features": smiles, "labels": [1, 2]}
+        )
+        if result is None or result["features"] is None:
+            return None
+        for property in self.properties:
+            property.encoder.eval = True
+            property_value = self.reader.read_property(smiles, property)
+            if property_value is None or len(property_value) == 0:
+                encoded_value = None
+            else:
+                encoded_value = torch.stack(
+                    [property.encoder.encode(v) for v in property_value]
+                )
+                if len(encoded_value.shape) == 3:
+                    encoded_value = encoded_value.squeeze(0)
+            result[property.name] = encoded_value
+
+        result["features"] = self._prediction_merge_props_into_base_wrapper(
+            result, model_hparams
+        )
+
+        # apply transformation, e.g. masking for pretraining task
+        if self.transform is not None:
+            result["features"] = self.transform(result["features"])
+
+        return result
+
+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
+        """
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
+        """
+        return self._merge_props_into_base(row)
+

 class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
     def __init__(
@@ -220,7 +268,7 @@ def __init__(
             f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}"
         )

-    def _merge_props_into_base(self, row: pd.Series) -> GeomData:
+    def _merge_props_into_base(self, row: pd.Series | dict) -> GeomData:
         """
         Merge encoded molecular properties into the GeomData object.
@@ -488,6 +536,8 @@ def _merge_props_into_base(
             A GeomData object with merged features.
         """
         geom_data = row["features"]
+        if geom_data is None:
+            return None
         assert isinstance(geom_data, GeomData)

         is_atom_node = geom_data.is_atom_node
@@ -571,6 +621,29 @@ def _merge_props_into_base(
             is_graph_node=is_graph_node,
         )

+    def _prediction_merge_props_into_base_wrapper(
+        self, row: pd.Series | dict, model_hparams: Optional[dict] = None
+    ) -> GeomData:
+        """
+        Wrapper to merge properties into base features for prediction.
+
+        Args:
+            row: A dictionary or pd.Series containing 'features' and encoded properties.
+        Returns:
+            A GeomData object with merged features.
+        """
+        if (
+            model_hparams is None
+            or "in_channels" not in model_hparams["config"]
+            or model_hparams["config"]["in_channels"] is None
+        ):
+            raise ValueError(
+                f"model_hparams must be provided for data class: {self.__class__.__name__}"
+                f" which should contain 'in_channels' key with valid value in 'config' dictionary."
+            )
+        max_len_node_properties = int(model_hparams["config"]["in_channels"])
+        return self._merge_props_into_base(row, max_len_node_properties)
+

 class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
     READER = RandomFeatureInitializationReader

configs/model/gat.yml

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@ init_args:
   num_layers: 4
   edge_dim: 7 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
-  dropout: 0
+  v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv`
   n_molecule_properties: 0
   n_linear_layers: 1

configs/model/gat_aug_aapool.yml

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@ init_args:
   num_layers: 4
   edge_dim: 11 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
-  dropout: 0
+  v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv`
  n_molecule_properties: 0
  n_linear_layers: 1

configs/model/gat_aug_amgpool.yml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ init_args:
   num_layers: 4
   edge_dim: 11 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: True # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
+  v2: True # This uses `torch_geometric.nn.conv.GATv2Conv` convolution layers, set False to use `GATConv`
   dropout: 0
   n_molecule_properties: 0
   n_linear_layers: 1

0 commit comments