Commit e3d3f99
Merge branch 'dev' into feature/chebi-from-list
2 parents: 6ae1518 + 8734dc8

37 files changed: +1167 -3013 lines

.gitignore
Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ chebai.egg-info
 lightning_logs
 logs
 .isort.cfg
-/.vscode
+/.vscode/launch.json
 
 *.out
 *.err

.vscode/extensions.json
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+{
+  "recommendations": [
+    "ms-python.python",
+    "ms-python.vscode-pylance",
+    "charliermarsh.ruff",
+    "usernamehw.errorlens"
+  ],
+  "unwantedRecommendations": [
+    "ms-python.vscode-python2"
+  ]
+}

.vscode/settings.json
Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+{
+  "python.testing.unittestArgs": [
+    "-v",
+    "-s",
+    "./tests",
+    "-p",
+    "test*.py"
+  ],
+  "python.testing.pytestEnabled": false,
+  "python.testing.unittestEnabled": true,
+  "python.analysis.typeCheckingMode": "basic",
+  "editor.formatOnSave": true,
+  "[python]": {
+    "editor.defaultFormatter": "charliermarsh.ruff"
+  }
+}

README.md
Lines changed: 11 additions & 6 deletions

@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
 ```
 A command with additional options may look like this:
 ```
-python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
+python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
 ### Fine-tuning for classification tasks, e.g. Toxicity prediction
@@ -78,11 +78,16 @@ python -m chebai fit --config=[path-to-your-esol-config] --trainer.callbacks=con
 
 ### Predicting classes given SMILES strings
 ```
-python3 -m chebai predict_from_file --model=[path-to-model-config] --checkpoint_path=[path-to-model] --input_path={path-to-file-containing-smiles] [--classes_path=[path-to-classes-file]] [--save_to=[path-to-output]]
+python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-to-model] --smiles_file_path=[path-to-file-containing-smiles] [--save_to=[path-to-output]]
 ```
-The input files should contain a list of line-separated SMILES strings. This generates a CSV file that contains
-one row for each SMILES string and one column for each class.
-The `classes_path` is the path to the dataset's `raw/classes.txt` file that contains the relationship between model output and ChEBI-IDs.
+
+* **`--checkpoint_path`**: Path to the Lightning checkpoint file (must end with `.ckpt`).
+
+* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
+
+* **`--save_to`** *(optional)*: Path where predictions are saved as a CSV file, with one row per SMILES string and one column per predicted class. Defaults to `predictions.csv` in the current working directory.
+
+> **Note**: This prediction pipeline requires checkpoints created after PR #148, since only those store the list of ChEBI classes (classification labels) used during training.
 
 ## Evaluation
 
@@ -96,7 +101,7 @@ An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 Alternatively, you can evaluate the model via the CLI:
 
 ```bash
-python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce_weighted.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
 ```
 
 > **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
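
The README changes above describe a prediction CSV with one row per SMILES string and one column per predicted class. A small sketch of what such an output might look like; the class IDs and probabilities here are invented for illustration, not real model output:

```python
import csv
import io

# Illustrative only: hypothetical class columns and made-up probabilities.
smiles = ["CCO", "c1ccccc1"]
classes = ["CHEBI:16236", "CHEBI:33595"]
probs = [[0.97, 0.02], [0.01, 0.95]]

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["smiles", *classes])  # header: one column per predicted class
for s, row in zip(smiles, probs):
    writer.writerow([s, *row])         # one row per input SMILES string
```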

chebai/cli.py
Lines changed: 6 additions & 2 deletions

@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
         apply_on="instantiate",
     )
 
+    parser.link_arguments(
+        "data.classes_txt_file_path",
+        "model.init_args.classes_txt_file_path",
+        apply_on="instantiate",
+    )
+
     for kind in ("train", "val", "test"):
         for average in (
             "micro-f1",
@@ -111,8 +117,6 @@ def subcommands() -> Dict[str, Set[str]]:
         "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
         "validate": {"model", "dataloaders", "datamodule"},
         "test": {"model", "dataloaders", "datamodule"},
-        "predict": {"model", "dataloaders", "datamodule"},
-        "predict_from_file": {"model"},
     }
 
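The `parser.link_arguments` call added to `chebai/cli.py` forwards the datamodule's `classes_txt_file_path` into the model's constructor at instantiation time. A rough stand-alone sketch of the effect; the helper function and example values are hypothetical, not part of the repository:

```python
# Hypothetical illustration of linking "data.classes_txt_file_path" to
# "model.init_args.classes_txt_file_path": the path configured on the
# datamodule ends up among the model's init kwargs.
def link_classes_path(data_cfg: dict, model_init_args: dict) -> dict:
    linked = dict(model_init_args)
    linked["classes_txt_file_path"] = data_cfg["classes_txt_file_path"]
    return linked

model_kwargs = link_classes_path(
    {"classes_txt_file_path": "data/chebi_v231/raw/classes.txt"},  # made-up path
    {"out_dim": 1446},                                             # made-up dim
)
```

In the real CLI, jsonargparse performs this forwarding automatically with `apply_on="instantiate"`, so users only configure the path once on the data side.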
chebai/models/base.py
Lines changed: 22 additions & 5 deletions

@@ -40,6 +40,7 @@ def __init__(
         pass_loss_kwargs: bool = True,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
+        classes_txt_file_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -62,6 +63,7 @@ def __init__(
                 "train_metrics",
                 "val_metrics",
                 "test_metrics",
+                "classes_txt_file_path",
                 *exclude_hyperparameter_logging,
             ]
         )
@@ -78,6 +80,23 @@ def __init__(
         self.test_metrics = test_metrics
         self.pass_loss_kwargs = pass_loss_kwargs
 
+        self.classes_txt_file_path = classes_txt_file_path
+
+        # During prediction `classes_txt_file_path` is set to None
+        if classes_txt_file_path is not None:
+            with open(classes_txt_file_path, "r") as f:
+                self.labels_list = [cls.strip() for cls in f.readlines()]
+            assert len(self.labels_list) > 0, "Class labels list is empty."
+            assert len(self.labels_list) == out_dim, (
+                f"Number of class labels ({len(self.labels_list)}) does not match "
+                f"the model output dimension ({out_dim})."
+            )
+
+    def on_save_checkpoint(self, checkpoint):
+        if self.classes_txt_file_path is not None:
+            # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
+            checkpoint["classification_labels"] = self.labels_list
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
         # different loss)
@@ -100,7 +119,7 @@ def __init_subclass__(cls, **kwargs):
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the predictions and labels from the model output.
 
@@ -151,7 +170,7 @@ def _process_for_loss(
         model_output: torch.Tensor,
         labels: torch.Tensor,
         loss_kwargs: Dict[str, Any],
-    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
         """
         Processes the data for loss computation.
 
@@ -237,7 +256,7 @@ def predict_step(
         Returns:
             Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step.
         """
-        return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False)
+        return self._execute(batch, batch_idx, log=False)
 
     def _execute(
         self,
@@ -324,8 +343,6 @@ def _execute(
         for metric_name, metric in metrics.items():
             metric.update(pr, tar)
         self._log_metrics(prefix, metrics, len(batch))
-        if isinstance(d, dict) and "loss" not in d:
-            print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
         return d
 
     def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
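
The `base.py` changes above read `classes.txt` at init, validate it against `out_dim`, and copy the labels into the checkpoint on save, so prediction no longer needs a separate classes file. A minimal dependency-free stand-in of that logic; the class name and file contents are illustrative, not the real `ChebaiBaseNet`:

```python
import os
import tempfile


class LabeledModelSketch:
    """Hypothetical stand-in for the label-handling logic added to the model."""

    def __init__(self, out_dim, classes_txt_file_path=None):
        self.classes_txt_file_path = classes_txt_file_path
        self.labels_list = None
        # During prediction, classes_txt_file_path is None and nothing is read.
        if classes_txt_file_path is not None:
            with open(classes_txt_file_path) as f:
                self.labels_list = [line.strip() for line in f if line.strip()]
            assert len(self.labels_list) == out_dim, "labels must match out_dim"

    def on_save_checkpoint(self, checkpoint):
        # Lightning passes the checkpoint dict by reference; extra keys added
        # here are persisted inside the .ckpt file.
        if self.classes_txt_file_path is not None:
            checkpoint["classification_labels"] = self.labels_list


# Simulate a checkpoint save with a tiny classes file (made-up ChEBI IDs).
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "classes.txt")
    with open(path, "w") as f:
        f.write("CHEBI:16236\nCHEBI:33595\n")
    model = LabeledModelSketch(out_dim=2, classes_txt_file_path=path)
    ckpt = {}
    model.on_save_checkpoint(ckpt)
```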

chebai/models/electra.py
Lines changed: 3 additions & 2 deletions

@@ -39,8 +39,8 @@ class ElectraPre(ChebaiBaseNet):
         replace_p (float): Probability of replacing tokens during training.
     """
 
-    def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
-        super().__init__(config=config, **kwargs)
+    def __init__(self, config: Dict[str, Any], **kwargs: Any):
+        super().__init__(**kwargs)
 
         self.generator_config = ElectraConfig(**config["generator"])
         self.generator = ElectraForMaskedLM(self.generator_config)
@@ -203,6 +203,7 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
             )
             * CLS_TOKEN
         )
+        model_kwargs["output_attentions"] = True
         return dict(
             features=torch.cat((cls_tokens, batch.x), dim=1),
             labels=batch.y,

chebai/preprocessing/bin/smiles_token/tokens.txt
Lines changed: 146 additions & 0 deletions

@@ -4375,3 +4375,149 @@ b
 [OH2]
 [TlH2+]
 [SbH6+3]
+[1*]
+[2*]
+[3*]
+[4*]
+[5*]
+[6*]
+[7*]
+[8*]
+[9*]
+[3He+]
+[12C+4]
+[16O+6]
+[11B-3]
+[11B+3]
+[31P+3]
+[31P+5]
+[34S+2]
+[34S+4]
+[34S+6]
+[55Mn+2]
+[55Mn+4]
+[55Mn+7]
+[57Fe+3]
+[59Co+2]
+[75As-3]
+[98Mo+3]
+[98Mo+6]
+[Cl:1]
+[c:2]
+[n:3]
+[c:4]
+[c:5]
+[H:24]
+[c:6]
+[H:25]
+[c:7]
+[H:26]
+[c:8]
+[H:27]
+[c:9]
+[c:10]
+[H:28]
+[c:11]
+[C:12]
+[O:13]
+[c:14]
+[c:15]
+[H:31]
+[c:16]
+[H:32]
+[c:17]
+[H:33]
+[c:18]
+[c:19]
+[H:34]
+[c:20]
+[H:35]
+[c:21]
+[H:36]
+[n:22]
+[c:23]
+[H:29]
+[H:30]
+[C:1]
+[C:2]
+[O:3]
+[O:4]
+[H:41]
+[H:42]
+[H:43]
+[H:44]
+[C:11]
+[O:12]
+[C:14]
+[C:15]
+[O:16]
+[N:17]
+[C:18]
+[C:19]
+[H:50]
+[C:20]
+[H:51]
+[H:52]
+[N:21]
+[c:25]
+[c:26]
+[H:53]
+[c:27]
+[F:37]
+[c:28]
+[N:31]
+[C:32]
+[H:56]
+[H:57]
+[C:33]
+[H:58]
+[H:59]
+[O:34]
+[C:35]
+[H:60]
+[H:61]
+[C:36]
+[H:62]
+[H:63]
+[c:29]
+[H:54]
+[c:30]
+[H:55]
+[C:22]
+[O:23]
+[O:24]
+[H:48]
+[H:49]
+[H:47]
+[H:45]
+[H:46]
+[H:38]
+[H:39]
+[H:40]
+[NaH2-]
+[KH2-]
+[C-2]
+[As+2]
+[P+2]
+[LiH2-]
+[BH2-3]
+[O+2]
+[BeH2-]
+[W@]
+[W@@]
+[RbH2-]
+[FrH2-]
+[AlH-2]
+[CsH2-]
+[B-2]
+[V@]
+[V@@]
+[V@OH]
+[*:0]
+[1*:0]
+[2*:0]
+[3*:0]
+[224RaH2]
+[226RaH2]
+[228RaH2]
+[*-:0]

chebai/preprocessing/collate.py
Lines changed: 5 additions & 1 deletion

@@ -87,7 +87,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
             *((d["features"], d["labels"], d.get("ident")) for d in data)
         )
         missing_labels = [
-            d.get("missing_labels", [False for _ in y[0]]) for d in data
+            d.get(
+                "missing_labels",
+                [False for _ in y[0]] if y[0] is not None else [False],
+            )
+            for d in data
        ]
 
        if any(x is not None for x in y):
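
The collate change guards the default missing-label mask against unlabeled samples: when a batch has no label vectors (`y[0]` is `None`), the old `[False for _ in y[0]]` expression would crash. Extracted as a standalone helper with a hypothetical name, the fallback behaves like this:

```python
# Hypothetical helper isolating the None-safe default added in the collator:
# if the sample provides no "missing_labels" entry, derive a per-class mask
# from the first label vector, or fall back to a one-element placeholder when
# no labels exist (e.g. during prediction).
def default_missing_labels(sample, first_labels):
    return sample.get(
        "missing_labels",
        [False for _ in first_labels] if first_labels is not None else [False],
    )
```

This is what lets predict-time batches, which carry no labels, pass through the same collation path as training batches.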
