
Commit 077602f

make gatv2 as default
1 parent 8d1d11d commit 077602f

File tree

4 files changed

+12
-10
lines changed


README.md

Lines changed: 9 additions & 5 deletions
@@ -94,7 +94,7 @@ Among all the connection schemes we evaluated, this configuration delivered the
 Below is the command for the model and data configuration that achieved the best classification performance using augmented graphs.
 
 ```bash
-python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.config.v2=True --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/wandb_logger.yml --model=../python-chebai-graph/configs/model/gat_aug_amgpool.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_aug_prop_as_per_node.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml --trainer.logger.init_args.name=gatv2_amg_s0
 ```
 
 ### Model Hyperparameters
@@ -104,7 +104,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.lo
 To use a GAT-based model, choose **one** of the following configs:
 
 - **Standard Pooling**: `--model=../python-chebai-graph/configs/model/gat.yml`
-> Standard pooling sums the learned representations from all the nodes to produce a single representation which is used for classification.
+  > Standard pooling sums the learned representations from all the nodes to produce a single representation which is used for classification.
 
 - **Atom-Augmented Node Pooling**: `--model=../python-chebai-graph/configs/model/gat_aug_aagpool.yml`
 > With this pooling strategy, the learned representations are first separated into **two distinct sets**: those from atom nodes and those from all artificial nodes (both functional groups and the graph node). The representations within each set are aggregated separately (using summation) to yield two distinct single vectors. These two resulting vectors are then concatenated before being passed to the classification layer.
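For readers unfamiliar with the two pooling variants above, here is a minimal NumPy sketch (the function names, toy data, and the boolean artificial-node mask are illustrative, not the chebai codebase's API; the actual models operate on PyTorch tensors):

```python
import numpy as np

def pool_standard(node_reprs):
    """Standard pooling: sum all node representations into one vector."""
    return node_reprs.sum(axis=0)

def pool_atom_augmented(node_reprs, is_artificial):
    """Atom-augmented pooling: sum atom nodes and artificial nodes
    (functional groups + graph node) separately, then concatenate."""
    atom_vec = node_reprs[~is_artificial].sum(axis=0)
    artificial_vec = node_reprs[is_artificial].sum(axis=0)
    return np.concatenate([atom_vec, artificial_vec])

# Toy graph: 3 atom nodes + 2 artificial nodes, 4-dim learned representations.
reprs = np.arange(20, dtype=float).reshape(5, 4)
mask = np.array([False, False, False, True, True])  # marks artificial nodes

print(pool_standard(reprs).shape)              # (4,)
print(pool_atom_augmented(reprs, mask).shape)  # (8,) -> classifier sees both sets
```

Note that the concatenated variant doubles the input dimension of the classification layer, which is why the two strategies need different model configs.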
@@ -117,9 +117,13 @@ To use a GAT-based model, choose **one** of the following configs:
 - **Number of message-passing layers**: `--model.config.num_layers=5` (default: 4)
 - **Attention heads**: `--model.config.heads=4` (default: 8)
 > **Note**: The number of heads should be divisible by the output channels (or hidden channels if output channels are not specified).
-- **Use GATv2**: `--model.config.v2=True` (default: False)
-> **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).
-
+
+- **To use a different GAT version**:
+  - **Use GAT**: `--model.config.v2=False`
+
+  - **Use GATv2**: `--model.config.v2=True` (**default**)
+> **Note**: GATv2 addresses the limitation of static attention in GAT by introducing a dynamic attention mechanism. For further details, please refer to the [original GATv2 paper](https://arxiv.org/abs/2105.14491).
+
 #### **ResGated Architecture**
 
 To use a ResGated GNN model, choose **one** of the following configs:
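As background for the `v2` flag, the scoring difference between GAT and GATv2 can be sketched in NumPy. This is a toy illustration of the equations from the GATv2 paper linked above, not the repository's code; in practice PyTorch Geometric's `GATConv`/`GATv2Conv` implement this internally, and all weights here are hand-picked for the example:

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_score(h_i, h_j, W, a):
    """GAT (v1): e_ij = LeakyReLU(a^T [W h_i || W h_j]).
    `a` is applied before the nonlinearity, so the attention is
    'static': the neighbor ranking is the same for every query node i."""
    z = np.concatenate([W @ h_i, W @ h_j])  # shape (2*d_out,)
    return float(leaky_relu(a @ z))

def gatv2_score(h_i, h_j, W, a):
    """GATv2: e_ij = a^T LeakyReLU(W [h_i || h_j]).
    Applying `a` after the nonlinearity yields dynamic attention."""
    z = leaky_relu(W @ np.concatenate([h_i, h_j]))  # shape (d_out,)
    return float(a @ z)

h_i, h_j = np.array([1.0, 0.0]), np.array([0.0, 1.0])
W1 = np.eye(2)                       # GAT weight: (d_out, d_in)
a1 = np.array([1.0, 2.0, 3.0, 4.0])  # GAT attention vector: (2*d_out,)
W2 = np.ones((2, 4))                 # GATv2 weight: (d_out, 2*d_in)
a2 = np.array([1.0, 2.0])            # GATv2 attention vector: (d_out,)

print(gat_score(h_i, h_j, W1, a1))    # 5.0
print(gatv2_score(h_i, h_j, W2, a2))  # 6.0
```

Note the two versions parameterize their weights differently: GAT transforms each node before concatenation, while GATv2 transforms the concatenated pair, which is what allows its scores to depend jointly on i and j.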

configs/model/gat.yml

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@ init_args:
   num_layers: 4
   edge_dim: 7 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
-  dropout: 0
+  v2: True # uses `torch_geometric.nn.conv.GATv2Conv` convolution layers; set False to use `GATConv`
   n_molecule_properties: 0
   n_linear_layers: 1

configs/model/gat_aug_aapool.yml

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@ init_args:
   num_layers: 4
   edge_dim: 11 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
-  dropout: 0
+  v2: True # uses `torch_geometric.nn.conv.GATv2Conv` convolution layers; set False to use `GATConv`
   n_molecule_properties: 0
   n_linear_layers: 1

configs/model/gat_aug_amgpool.yml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ init_args:
   num_layers: 4
   edge_dim: 11 # number of bond properties
   heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given)
-  v2: True # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv
+  v2: True # uses `torch_geometric.nn.conv.GATv2Conv` convolution layers; set False to use `GATConv`
   dropout: 0
   n_molecule_properties: 0
   n_linear_layers: 1
