Commit 679a5cb

Authored by Copilot and aditya0by0

Fix review comments: type hints, error handling, and documentation (#152)

* Initial plan
* Address review comments from PR #148

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: aditya0by0 <65857172+aditya0by0@users.noreply.github.com>

1 parent: fea402e

File tree: 4 files changed (+8 −9 lines)

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -63,7 +63,7 @@ python -m chebai fit --trainer=configs/training/default_trainer.yml --model=conf
 ```
 A command with additional options may look like this:
 ```
-python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_unweighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
+python3 -m chebai fit --trainer=configs/training/default_trainer.yml --model=configs/model/electra.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --model.pretrained_checkpoint=electra_pretrained.ckpt --model.load_prefix=generator. --data=configs/data/chebi/chebi50.yml --model.criterion=configs/loss/bce_weighted.yml --data.init_args.batch_size=10 --trainer.logger.init_args.name=chebi50_bce_weighted --data.init_args.num_workers=9 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --data.init_args.data_limit=1000
 ```
 
 ### Fine-tuning for classification tasks, e.g. Toxicity prediction
````

chebai/preprocessing/datasets/base.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -445,7 +445,7 @@ def _process_input_for_prediction(
 
         Args:
             smiles_list (List[str]): List of SMILES strings.
-            model_hparams (Optional[dict]): Model hyperparameters.
+            model_hparams (dict): Model hyperparameters.
                 Some prediction pre-processing pipelines may require these.
 
         Returns:
@@ -467,7 +467,7 @@ def _process_input_for_prediction(
         return data, valid_indices
 
     def _preprocess_smiles_for_pred(
-        self, idx, smiles: str, model_hparams: Optional[dict] = None
+        self, idx: int, smiles: str, model_hparams: Optional[dict] = None
     ) -> dict:
         """Preprocess prediction data."""
         # Add dummy labels because the collate function requires them.
```
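The signature change above just adds a missing annotation. As a minimal sketch of the fully annotated pattern, `preprocess` below is a hypothetical stand-in for `_preprocess_smiles_for_pred` (the tokenization and key names are illustrative, not the library's actual behavior):

```python
from typing import Optional

def preprocess(idx: int, smiles: str, model_hparams: Optional[dict] = None) -> dict:
    # Hypothetical stand-in: every parameter is annotated, and the
    # optional hparams dict defaults to None rather than a mutable {}.
    hparams = model_hparams if model_hparams is not None else {}
    return {
        "idx": idx,
        "features": list(smiles),           # toy tokenization: one token per character
        "max_len": hparams.get("max_len"),  # None when the caller passed no hparams
    }

print(preprocess(0, "CCO"))
```

Defaulting to `None` and normalizing inside the function avoids the classic mutable-default-argument pitfall while keeping the `Optional[dict]` hint honest.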

chebai/preprocessing/migration/migrate_checkpoints.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -41,8 +41,8 @@ def add_class_labels_to_checkpoint(input_path, classes_file_path):
 
 
 if __name__ == "__main__":
-    if len(sys.argv) < 2:
-        print("Usage: python modify_checkpoints.py <input_checkpoint> <classes_file>")
+    if len(sys.argv) < 3:
+        print("Usage: python migrate_checkpoints.py <input_checkpoint> <classes_file>")
         sys.exit(1)
 
     input_ckpt = sys.argv[1]
```
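The `< 3` bound is correct because `sys.argv[0]` is the script path itself, so a script requiring two real arguments must see a list of length 3. A minimal sketch of that check (the `main` wrapper is an illustration, not the script's actual structure):

```python
def main(argv: list) -> int:
    # argv[0] is the script path, so two required arguments
    # mean len(argv) must be at least 3, not 2.
    if len(argv) < 3:
        print("Usage: python migrate_checkpoints.py <input_checkpoint> <classes_file>")
        return 1
    input_ckpt, classes_file = argv[1], argv[2]
    print(f"Would migrate {input_ckpt} using labels from {classes_file}")
    return 0
```

In the real script this would be driven by `sys.exit(main(sys.argv))`; returning an exit code instead of calling `sys.exit` inline keeps the check unit-testable.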

chebai/result/prediction.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -71,9 +71,8 @@ def __init__(
         self._model.to(self.device)
         print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
 
-        try:
-            self._classification_labels: list = ckpt_file.get("classification_labels")
-        except KeyError:
+        self._classification_labels: list = ckpt_file.get("classification_labels")
+        if self._classification_labels is None:
             raise KeyError(
                 "The checkpoint does not contain 'classification_labels'. "
                 "Make sure the checkpoint is compatible with python-chebai version 1.2.1 or later."
@@ -140,7 +139,7 @@ def predict_smiles(
         Returns:
             A tensor containing the predictions.
         """
-        # For certain data prediction piplines, we may need model hyperparameters
+        # For certain data prediction pipelines, we may need model hyperparameters
         pred_dl, valid_indices = self._dm.predict_dataloader(
             smiles_list=smiles, model_hparams=self._model_hparams
         )
```
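The error-handling fix above works because `dict.get` returns `None` for a missing key rather than raising, so the old `try/except KeyError` could never trigger. A minimal sketch of the corrected pattern (`load_labels` is a hypothetical extraction of the checkpoint-loading logic, not the module's actual API):

```python
def load_labels(ckpt_file: dict) -> list:
    # dict.get() returns None for a missing key instead of raising,
    # so wrapping it in try/except KeyError is dead code; test for None.
    labels = ckpt_file.get("classification_labels")
    if labels is None:
        raise KeyError(
            "The checkpoint does not contain 'classification_labels'. "
            "Make sure the checkpoint is compatible with python-chebai 1.2.1 or later."
        )
    return labels
```

Note one edge case: a checkpoint that stores the key with an explicit `None` value is treated the same as a missing key, which is the desired behavior here.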
