Skip to content

Commit 64e93ab

Browse files
authored
Merge branch 'dev' into feature/chebi2.0-adaption
2 parents e713104 + 5243e02 commit 64e93ab

File tree

19 files changed

+764
-955
lines changed

19 files changed

+764
-955
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ chebai.egg-info
175175
lightning_logs
176176
logs
177177
.isort.cfg
178-
/.vscode
178+
/.vscode/launch.json
179179

180180
*.out
181181
*.err

.vscode/extensions.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"recommendations": [
3+
"ms-python.python",
4+
"ms-python.vscode-pylance",
5+
"charliermarsh.ruff",
6+
"usernamehw.errorlens"
7+
],
8+
"unwantedRecommendations": [
9+
"ms-python.vscode-python2"
10+
]
11+
}

.vscode/settings.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"python.testing.unittestArgs": [
3+
"-v",
4+
"-s",
5+
"./tests",
6+
"-p",
7+
"test*.py"
8+
],
9+
"python.testing.pytestEnabled": false,
10+
"python.testing.unittestEnabled": true,
11+
"python.analysis.typeCheckingMode": "basic",
12+
"editor.formatOnSave": true,
13+
"[python]": {
14+
"editor.defaultFormatter": "charliermarsh.ruff"
15+
}
16+
}

LICENSE

Lines changed: 21 additions & 661 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
6363
```
6464
A command with additional options may look like this:
6565
```
66-
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
66+
python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
6767
```
6868

6969
### Fine-tuning for classification tasks, e.g. Toxicity prediction
@@ -78,11 +78,16 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
7878

7979
### Predicting classes given SMILES strings
8080
```
81-
python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
81+
python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]]
8282
```
83-
The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains the
84-
one row for each SMILES string and one column for each class.
85-
The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
83+
84+
* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
85+
86+
* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
87+
88+
* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`.
89+
90+
> **Note**: Newly created checkpoints after PR #148 must be used for this prediction pipeline. The list of ChEBI classes (classification labels) used during training is stored in new checkpoints, which are required.
8691
8792
## Evaluation
8893

@@ -96,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
96101
Alternatively, you can evaluate the model via the CLI:
97102

98103
```bash
99-
python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
104+
python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
100105
```
101106

102107
> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.

chebai/cli.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
5959
apply_on="instantiate",
6060
)
6161

62+
parser.link_arguments(
63+
"data.classes_txt_file_path",
64+
"model.init_args.classes_txt_file_path",
65+
apply_on="instantiate",
66+
)
67+
6268
for kind in ("train", "val", "test"):
6369
for average in (
6470
"micro-f1",
@@ -111,8 +117,6 @@ def subcommands() -> Dict[str, Set[str]]:
111117
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
112118
"validate": {"model", "dataloaders", "datamodule"},
113119
"test": {"model", "dataloaders", "datamodule"},
114-
"predict": {"model", "dataloaders", "datamodule"},
115-
"predict_from_file": {"model"},
116120
}
117121

118122

chebai/models/base.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,16 @@ def __init__(
4040
pass_loss_kwargs: bool = True,
4141
optimizer_kwargs: Optional[Dict[str, Any]] = None,
4242
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
43+
classes_txt_file_path: Optional[str] = None,
4344
**kwargs,
4445
):
4546
super().__init__(**kwargs)
4647
# super().__init__()
4748
if exclude_hyperparameter_logging is None:
4849
exclude_hyperparameter_logging = tuple()
4950
self.criterion = criterion
50-
assert out_dim is not None, "out_dim must be specified"
51-
assert input_dim is not None, "input_dim must be specified"
51+
assert out_dim is not None and out_dim > 0, "out_dim must be specified"
52+
assert input_dim is not None and input_dim > 0, "input_dim must be specified"
5253
self.out_dim = out_dim
5354
self.input_dim = input_dim
5455
print(
@@ -62,6 +63,7 @@ def __init__(
6263
"train_metrics",
6364
"val_metrics",
6465
"test_metrics",
66+
"classes_txt_file_path",
6567
*exclude_hyperparameter_logging,
6668
]
6769
)
@@ -78,6 +80,23 @@ def __init__(
7880
self.test_metrics = test_metrics
7981
self.pass_loss_kwargs = pass_loss_kwargs
8082

83+
self.classes_txt_file_path = classes_txt_file_path
84+
85+
# During prediction `classes_txt_file_path` is set to None
86+
if classes_txt_file_path is not None:
87+
with open(classes_txt_file_path, "r") as f:
88+
self.labels_list = [cls.strip() for cls in f.readlines()]
89+
assert len(self.labels_list) > 0, "Class labels list is empty."
90+
assert len(self.labels_list) == out_dim, (
91+
f"Number of class labels ({len(self.labels_list)}) does not match "
92+
f"the model output dimension ({out_dim})."
93+
)
94+
95+
def on_save_checkpoint(self, checkpoint):
96+
if self.classes_txt_file_path is not None:
97+
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
98+
checkpoint["classification_labels"] = self.labels_list
99+
81100
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
82101
# avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
83102
# different loss)
@@ -100,7 +119,7 @@ def __init_subclass__(cls, **kwargs):
100119

101120
def _get_prediction_and_labels(
102121
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
103-
) -> (torch.Tensor, torch.Tensor):
122+
) -> tuple[torch.Tensor, torch.Tensor]:
104123
"""
105124
Gets the predictions and labels from the model output.
106125
@@ -151,7 +170,7 @@ def _process_for_loss(
151170
model_output: torch.Tensor,
152171
labels: torch.Tensor,
153172
loss_kwargs: Dict[str, Any],
154-
) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
173+
) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
155174
"""
156175
Processes the data for loss computation.
157176
@@ -237,7 +256,7 @@ def predict_step(
237256
Returns:
238257
Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
239258
"""
240-
return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
259+
return self._execute(batch, batch_idx, log=False)
241260

242261
def _execute(
243262
self,

chebai/models/electra.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
203203
)
204204
* CLS_TOKEN
205205
)
206+
model_kwargs["output_attentions"] = True
206207
return dict(
207208
features=torch.cat((cls_tokens, batch.x), dim=1),
208209
labels=batch.y,

chebai/preprocessing/bin/smiles_token/tokens.txt

Lines changed: 2 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -4373,181 +4373,5 @@ b
43734373
[CaH2]
43744374
[NH3]
43754375
[OH2]
4376-
[1*]
4377-
[2*]
4378-
[3*]
4379-
[4*]
4380-
[5*]
4381-
[6*]
4382-
[7*]
4383-
[8*]
4384-
[9*]
4385-
[3He+]
4386-
[12C+4]
4387-
[16O+6]
4388-
[11B-3]
4389-
[11B+3]
4390-
[31P+3]
4391-
[31P+5]
4392-
[34S+2]
4393-
[34S+4]
4394-
[34S+6]
4395-
[55Mn+2]
4396-
[55Mn+4]
4397-
[55Mn+7]
4398-
[57Fe+3]
4399-
[59Co+2]
4400-
[75As-3]
4401-
[98Mo+3]
4402-
[98Mo+6]
4403-
[Cl:1]
4404-
[c:2]
4405-
[n:3]
4406-
[c:4]
4407-
[cH:5]
4408-
[cH:6]
4409-
[cH:7]
4410-
[cH:8]
4411-
[c:9]
4412-
[cH:10]
4413-
[c:11]
4414-
[CH2:12]
4415-
[O:13]
4416-
[c:14]
4417-
[cH:15]
4418-
[cH:16]
4419-
[cH:17]
4420-
[c:18]
4421-
[cH:19]
4422-
[cH:20]
4423-
[cH:21]
4424-
[n:22]
4425-
[c:23]
4426-
[CH3:1]
4427-
[C:2]
4428-
[O:3]
4429-
[O:4]
4430-
[c:5]
4431-
[cH:9]
4432-
[c:10]
4433-
[C:11]
4434-
[O:12]
4435-
[CH2:14]
4436-
[C:15]
4437-
[O:16]
4438-
[NH:17]
4439-
[CH2:18]
4440-
[CH:19]
4441-
[CH2:20]
4442-
[N:21]
4443-
[c:25]
4444-
[cH:26]
4445-
[c:27]
4446-
[F:37]
4447-
[c:28]
4448-
[N:31]
4449-
[CH2:32]
4450-
[CH2:33]
4451-
[O:34]
4452-
[CH2:35]
4453-
[CH2:36]
4454-
[cH:29]
4455-
[cH:30]
4456-
[C:22]
4457-
[O:23]
4458-
[O:24]
4459-
[NaH2-]
4460-
[KH2-]
4461-
[LiH2-]
4462-
[BH2-3]
4463-
[BeH3-]
4464-
[RbH2-]
4465-
[FrH2-]
4466-
[AlH-2]
4467-
[CsH2-]
4468-
[P@TB5]
4469-
[Ru@OH14]
4470-
[Ru@OH15+2]
4471-
[Ru@OH16]
4472-
[Ru@OH23+2]
4473-
[Ru@OH4]
4474-
[*:0]
4475-
[1*:0]
4476-
[2*:0]
4477-
[3*:0]
4478-
[224RaH2]
4479-
[226RaH2]
4480-
[228RaH2]
4481-
[H:24]
4482-
[c:6]
4483-
[H:25]
4484-
[c:7]
4485-
[H:26]
4486-
[c:8]
4487-
[H:27]
4488-
[H:28]
4489-
[C:12]
4490-
[c:15]
4491-
[H:31]
4492-
[c:16]
4493-
[H:32]
4494-
[c:17]
4495-
[H:33]
4496-
[c:19]
4497-
[H:34]
4498-
[c:20]
4499-
[H:35]
4500-
[c:21]
4501-
[H:36]
4502-
[H:29]
4503-
[H:30]
4504-
[C:1]
4505-
[H:41]
4506-
[H:42]
4507-
[H:43]
4508-
[H:44]
4509-
[C:14]
4510-
[N:17]
4511-
[C:18]
4512-
[C:19]
4513-
[H:50]
4514-
[C:20]
4515-
[H:51]
4516-
[H:52]
4517-
[c:26]
4518-
[H:53]
4519-
[C:32]
4520-
[H:56]
4521-
[H:57]
4522-
[C:33]
4523-
[H:58]
4524-
[H:59]
4525-
[C:35]
4526-
[H:60]
4527-
[H:61]
4528-
[C:36]
4529-
[H:62]
4530-
[H:63]
4531-
[c:29]
4532-
[H:54]
4533-
[c:30]
4534-
[H:55]
4535-
[H:48]
4536-
[H:49]
4537-
[H:47]
4538-
[H:45]
4539-
[H:46]
4540-
[H:38]
4541-
[H:39]
4542-
[H:40]
4543-
[C-2]
4544-
[As+2]
4545-
[P+2]
4546-
[O+2]
4547-
[BeH2-]
4548-
[W@]
4549-
[W@@]
4550-
[B-2]
4551-
[V@]
4552-
[V@@]
4553-
[V@OH]
4376+
[TlH2+]
4377+
[SbH6+3]

0 commit comments

Comments
 (0)