Skip to content

Commit 0ea3490

Browse files
committed
save classification labels to checkpoints
1 parent 811fbf1 commit 0ea3490

File tree

4 files changed

+46
-25
lines changed

4 files changed

+46
-25
lines changed

chebai/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
5959
apply_on="instantiate",
6060
)
6161

62+
parser.link_arguments(
63+
"data.classes_txt_file_path",
64+
"model.init_args.classes_txt_file_path",
65+
apply_on="instantiate",
66+
)
67+
6268
for kind in ("train", "val", "test"):
6369
for average in (
6470
"micro-f1",

chebai/models/base.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
pass_loss_kwargs: bool = True,
4141
optimizer_kwargs: Optional[Dict[str, Any]] = None,
4242
exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
43+
classes_txt_file_path: Optional[str] = None,
4344
**kwargs,
4445
):
4546
super().__init__(**kwargs)
@@ -77,6 +78,17 @@ def __init__(
7778
self.validation_metrics = val_metrics
7879
self.test_metrics = test_metrics
7980
self.pass_loss_kwargs = pass_loss_kwargs
81+
with open(classes_txt_file_path, "r") as f:
82+
self.labels_list = [cls.strip() for cls in f.readlines()]
83+
assert len(self.labels_list) > 0, "Class labels list is empty."
84+
assert len(self.labels_list) == out_dim, (
85+
f"Number of class labels ({len(self.labels_list)}) does not match "
86+
f"the model output dimension ({out_dim})."
87+
)
88+
89+
def on_save_checkpoint(self, checkpoint):
90+
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
91+
checkpoint["classification_labels"] = self.labels_list
8092

8193
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
8294
# avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a
@@ -100,7 +112,7 @@ def __init_subclass__(cls, **kwargs):
100112

101113
def _get_prediction_and_labels(
102114
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
103-
) -> (torch.Tensor, torch.Tensor):
115+
) -> tuple[torch.Tensor, torch.Tensor]:
104116
"""
105117
Gets the predictions and labels from the model output.
106118
@@ -151,7 +163,7 @@ def _process_for_loss(
151163
model_output: torch.Tensor,
152164
labels: torch.Tensor,
153165
loss_kwargs: Dict[str, Any],
154-
) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
166+
) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
155167
"""
156168
Processes the data for loss computation.
157169

chebai/preprocessing/datasets/base.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def __init__(
9696
self.prediction_kind = prediction_kind
9797
self.data_limit = data_limit
9898
self.label_filter = label_filter
99-
assert (balance_after_filter is not None) or (
100-
self.label_filter is None
101-
), "Filter balancing requires a filter"
99+
assert (balance_after_filter is not None) or (self.label_filter is None), (
100+
"Filter balancing requires a filter"
101+
)
102102
self.balance_after_filter = balance_after_filter
103103
self.num_workers = num_workers
104104
self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
108108
self.use_inner_cross_validation = (
109109
inner_k_folds > 1
110110
) # only use cv if there are at least 2 folds
111-
assert (
112-
fold_index is None or self.use_inner_cross_validation is not None
113-
), "fold_index can only be set if cross validation is used"
111+
assert fold_index is None or self.use_inner_cross_validation is not None, (
112+
"fold_index can only be set if cross validation is used"
113+
)
114114
if fold_index is not None and self.inner_k_folds is not None:
115-
assert (
116-
fold_index < self.inner_k_folds
117-
), "fold_index can't be larger than the total number of folds"
115+
assert fold_index < self.inner_k_folds, (
116+
"fold_index can't be larger than the total number of folds"
117+
)
118118
self.fold_index = fold_index
119119
self._base_dir = base_dir
120120
self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
137137

138138
@property
139139
def feature_vector_size(self):
140-
assert (
141-
self._feature_vector_size is not None
142-
), "size of feature vector must be set"
140+
assert self._feature_vector_size is not None, (
141+
"size of feature vector must be set"
142+
)
143143
return self._feature_vector_size
144144

145145
@property
@@ -619,6 +619,19 @@ def raw_file_names_dict(self) -> dict:
619619
"""
620620
raise NotImplementedError
621621

622+
@property
623+
def classes_txt_file_path(self) -> str:
624+
"""
625+
Returns the filename for the classes text file.
626+
627+
Returns:
628+
str: The filename for the classes text file.
629+
"""
630+
# This property is also used in the following places:
631+
# - results/prediction.py: to load class names for CSV column names
632+
# - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path`
633+
return os.path.join(self.processed_dir_main, "classes.txt")
634+
622635

623636
class MergedDataset(XYBaseDataModule):
624637
MERGED = []
@@ -1373,14 +1386,3 @@ def processed_file_names_dict(self) -> dict:
13731386
if self.n_token_limit is not None:
13741387
return {"data": f"data_maxlen{self.n_token_limit}.pt"}
13751388
return {"data": "data.pt"}
1376-
1377-
@property
1378-
def classes_txt_file_path(self) -> str:
1379-
"""
1380-
Returns the filename for the classes text file.
1381-
1382-
Returns:
1383-
str: The filename for the classes text file.
1384-
"""
1385-
# This property is also used in the custom trainer `chebai/trainer/CustomTrainer.py`
1386-
return os.path.join(self.processed_dir_main, "classes.txt")

chebai/result/prediction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def _add_class_columns(class_file_path: _PATH) -> list[str]:
126126
predictions_df = pd.DataFrame(rows, columns=CLASS_LABELS, index=smiles_strings)
127127

128128
predictions_df.to_csv(save_to)
129+
print(f"Predictions saved to: {save_to}")
129130

130131
@torch.inference_mode()
131132
def predict_smiles(

0 commit comments

Comments (0)