Skip to content

Commit a1cdaca

Browse files
committed
update for prediction logic how ckpts with class labels
1 parent 0ea3490 commit a1cdaca

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ python3 chebai/result/prediction.py predict_from_file --checkpoint_path=[path-t
8181

8282
* **`--smiles_file_path`**: Path to a text file containing one SMILES string per line.
8383

84-
* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class.
85-
86-
* **`--classes_path`** *(optional)*: Path to the dataset’s `raw/classes.txt` file, which maps model output indices to ChEBI IDs.
84+
* **`--save_to`** *(optional)*: Predictions will be saved to the path as CSV file. The CSV will contain one row per SMILES string and one column per predicted class. Default path will be the current working directory with file name as `predictions.csv`.
8785

86+
* **`--classes_path`** *(optional)*: Path to the dataset’s `classes.txt` file, which maps model output indices to ChEBI IDs.
87+
* Checkpoints created after PR #135 will have the classification labels stored in them and hence this parameter is not required.
8888
* If provided, the CSV columns will be named using the ChEBI IDs.
8989
* If omitted, then script will located the file automatically. If unable to locate then the columns will be numbered sequentially.
9090

chebai/result/prediction.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ def __init__(
6262
)
6363
print("*" * 10, f"Loaded model class: {self._model.__class__.__name__}")
6464

65+
self._classification_labels: list | None = ckpt_file.get(
66+
"classification_labels", None
67+
)
68+
if self._classification_labels is not None:
69+
print(f"Loaded {len(self._classification_labels)} classification labels.")
70+
assert len(self._classification_labels) > 0, (
71+
"Classification labels list is empty."
72+
)
73+
assert len(self._classification_labels) == self._model.out_dim, (
74+
f"Number of class labels ({len(self._classification_labels)}) does not match "
75+
f"the model output dimension ({self._model.out_dim})."
76+
)
77+
6578
if compile_model:
6679
self._model = torch.compile(self._model)
6780
self._model.eval()
@@ -92,7 +105,10 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]:
92105
with open(class_file_path, "r") as f:
93106
return [cls.strip() for cls in f.readlines()]
94107

95-
if classes_path is not None:
108+
if self._classification_labels is not None:
109+
CLASS_LABELS = self._classification_labels
110+
# --- For old checkpoints that do not have classification_labels saved ---
111+
elif classes_path is not None:
96112
CLASS_LABELS = _add_class_columns(classes_path)
97113
elif os.path.exists(self._dm.classes_txt_file_path):
98114
CLASS_LABELS = _add_class_columns(self._dm.classes_txt_file_path)
@@ -102,6 +118,7 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]:
102118
print("No valid predictions were made. (All predictions are None.)")
103119
return
104120

121+
# --- Logic for old checkpoints that do not have classification_labels saved ---
105122
if CLASS_LABELS is not None and self._model.out_dim is not None:
106123
assert len(CLASS_LABELS) > 0, "Class labels list is empty."
107124
assert len(CLASS_LABELS) == self._model.out_dim, (

0 commit comments

Comments
 (0)