-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathbase.py
More file actions
413 lines (358 loc) · 15.6 KB
/
base.py
File metadata and controls
413 lines (358 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Union
import torch
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from chebai.preprocessing.structures import XYData
# Mute the "pysmiles" logger: only records at CRITICAL level are emitted from it.
logging.getLogger("pysmiles").setLevel(level=logging.CRITICAL)
# Registry of model classes keyed by class name; ChebaiBaseNet.__init_subclass__ adds entries to it.
_MODEL_REGISTRY: Dict[str, type] = {}
class ChebaiBaseNet(LightningModule, ABC):
    """
    Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.

    Subclasses must implement :meth:`forward`. Every subclass is registered under its class name in the
    module-level ``_MODEL_REGISTRY``; defining two subclasses with the same name raises a ``ValueError``.

    Args:
        criterion (torch.nn.Module, optional): The loss criterion for the model. If None, no loss is
            computed or logged. Defaults to None.
        out_dim (int): The output dimension of the model. Must be given and positive.
        input_dim (int): The input dimension of the model. Must be given and positive.
        train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None.
        val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None.
        test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None.
        pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True.
        optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer.
            Defaults to None (no extra arguments).
        exclude_hyperparameter_logging (Iterable[str], optional): Names of constructor arguments that are
            additionally excluded from ``save_hyperparameters``. Defaults to None (nothing extra excluded).
        classes_txt_file_path (str): Path to a text file with one class label per line. Must be given;
            the number of labels must equal ``out_dim``.
        **kwargs: Additional keyword arguments, forwarded to ``LightningModule.__init__``.
    """

    def __init__(
        self,
        criterion: Optional[torch.nn.Module] = None,
        out_dim: Optional[int] = None,
        input_dim: Optional[int] = None,
        train_metrics: Optional[torch.nn.Module] = None,
        val_metrics: Optional[torch.nn.Module] = None,
        test_metrics: Optional[torch.nn.Module] = None,
        pass_loss_kwargs: bool = True,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        exclude_hyperparameter_logging: Optional[Iterable[str]] = None,
        classes_txt_file_path: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if exclude_hyperparameter_logging is None:
            exclude_hyperparameter_logging = tuple()
        self.criterion = criterion
        assert out_dim is not None and out_dim > 0, "out_dim must be specified"
        assert input_dim is not None and input_dim > 0, "input_dim must be specified"
        self.out_dim = out_dim
        self.input_dim = input_dim
        print(
            f"Input dimension for the model: {self.input_dim}",
            f"Output dimension for the model: {self.out_dim}",
        )
        # Module-valued arguments (loss and metrics) are kept out of the saved hyperparameters,
        # together with any names the caller asked to exclude.
        self.save_hyperparameters(
            ignore=[
                "criterion",
                "train_metrics",
                "val_metrics",
                "test_metrics",
                *exclude_hyperparameter_logging,
            ]
        )
        self.hparams["out_dim"] = out_dim
        self.hparams["input_dim"] = input_dim
        self.optimizer_kwargs = optimizer_kwargs or dict()
        self.train_metrics = train_metrics
        self.validation_metrics = val_metrics
        self.test_metrics = test_metrics
        self.pass_loss_kwargs = pass_loss_kwargs

        # Fail with a clear message instead of the opaque TypeError that open(None) would raise.
        assert (
            classes_txt_file_path is not None
        ), "classes_txt_file_path must be specified"
        with open(classes_txt_file_path, "r", encoding="utf-8") as f:
            # One label per line; surrounding whitespace (including the newline) is stripped.
            self.labels_list = [line.strip() for line in f]
        assert len(self.labels_list) > 0, "Class labels list is empty."
        assert len(self.labels_list) == out_dim, (
            f"Number of class labels ({len(self.labels_list)}) does not match "
            f"the model output dimension ({out_dim})."
        )

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """
        Stores the class labels in the checkpoint under the key ``"classification_labels"``.

        Args:
            checkpoint (Dict[str, Any]): The checkpoint dictionary; modified in place.
        """
        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html#modify-a-checkpoint-anywhere
        checkpoint["classification_labels"] = self.labels_list

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """
        Drops criterion ``pos_weight`` entries from the checkpoint's state dict before it is loaded.

        This avoids errors due to unexpected keys, e.g. when loading a checkpoint from a BCE model and
        using it with a different loss.

        Args:
            checkpoint (Dict[str, Any]): The checkpoint dictionary; its ``"state_dict"`` is modified in place.
        """
        state_dict = checkpoint["state_dict"]
        for key in ("criterion.base_loss.pos_weight", "criterion.pos_weight"):
            state_dict.pop(key, None)

    def __init_subclass__(cls, **kwargs):
        """
        Automatically registers subclasses in the model registry to prevent duplicates.

        Args:
            **kwargs: Class keyword arguments, forwarded to the parent ``__init_subclass__``.

        Raises:
            ValueError: If a model class with the same name is already registered.
        """
        # Keep the hook cooperative so parent classes still see subclass creation.
        super().__init_subclass__(**kwargs)
        if cls.__name__ in _MODEL_REGISTRY:
            raise ValueError(f"Model {cls.__name__} does already exist")
        _MODEL_REGISTRY[cls.__name__] = cls

    def _get_prediction_and_labels(
        self,
        data: Dict[str, Any],
        labels: Optional[torch.Tensor],
        output: torch.Tensor,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Gets the predictions and labels from the model output.

        This base implementation returns the model output unchanged as the predictions and casts the
        labels to ``torch.int``.

        Args:
            data (Dict[str, Any]): The processed batch data (unused here).
            labels (Optional[torch.Tensor]): The true labels, or None.
            output (torch.Tensor): The model output.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Tensor]]: Predictions and labels (None if ``labels`` was None).
        """
        int_labels = labels.to(torch.int) if labels is not None else None
        return output, int_labels

    def _process_labels_in_batch(self, batch: XYData) -> Optional[torch.Tensor]:
        """
        Processes the labels in the batch.

        Args:
            batch (XYData): The input batch of data.

        Returns:
            Optional[torch.Tensor]: ``batch.y`` cast to float, or None if the batch carries no labels.
        """
        # Unlabelled batches (e.g. during prediction) must not crash here: downstream code in
        # _get_prediction_and_labels and _execute already handles ``labels is None``.
        if batch.y is None:
            return None
        return batch.y.float()

    def _process_batch(self, batch: XYData, batch_idx: int) -> Dict[str, Any]:
        """
        Processes the batch data.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.

        Returns:
            Dict[str, Any]: Processed batch data with the keys ``features``, ``labels``, ``model_kwargs``,
            ``loss_kwargs`` and ``idents``.
        """
        return dict(
            features=batch.x,
            labels=self._process_labels_in_batch(batch),
            model_kwargs=batch.additional_fields["model_kwargs"],
            loss_kwargs=batch.additional_fields["loss_kwargs"],
            idents=batch.additional_fields["idents"],
        )

    def _process_for_loss(
        self,
        model_output: torch.Tensor,
        labels: torch.Tensor,
        loss_kwargs: Dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Processes the data for loss computation. The base implementation returns its inputs unchanged.

        Args:
            model_output (torch.Tensor): The model output.
            labels (torch.Tensor): The true labels.
            loss_kwargs (Dict[str, Any]): Additional keyword arguments for the loss function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: Model output, labels, and loss kwargs.
        """
        return model_output, labels, loss_kwargs

    def on_train_epoch_start(self) -> None:
        """Logs the epoch start and forwards the current epoch to the datamodule if it has ``curr_epoch``."""
        # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset)
        rank_zero_info(f"Starting epoch {self.current_epoch}")
        if hasattr(self.trainer.datamodule, "curr_epoch"):
            rank_zero_info(f"Setting datamodule.curr_epoch to {self.current_epoch}")
            self.trainer.datamodule.curr_epoch = self.current_epoch

    def training_step(
        self, batch: XYData, batch_idx: int
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Defines the training step.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: The result of the training step.
        """
        return self._execute(
            batch, batch_idx, self.train_metrics, prefix="train_", sync_dist=True
        )

    def validation_step(
        self, batch: XYData, batch_idx: int
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Defines the validation step.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: The result of the validation step.
        """
        return self._execute(
            batch, batch_idx, self.validation_metrics, prefix="val_", sync_dist=True
        )

    def test_step(
        self, batch: XYData, batch_idx: int
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Defines the test step.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: The result of the test step.
        """
        return self._execute(
            batch, batch_idx, self.test_metrics, prefix="test_", sync_dist=True
        )

    def predict_step(self, batch: XYData, batch_idx: int, **kwargs) -> torch.Tensor:
        """
        Defines the prediction step.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.Tensor: The predictions produced by ``_get_prediction_and_labels``.
        """
        assert isinstance(batch, XYData)
        batch = batch.to(self.device)
        data = self._process_batch(batch, batch_idx)
        model_output = self(data, **data.get("model_kwargs", dict()))
        # Dummy labels to avoid errors in _get_prediction_and_labels; allocated directly on the model device.
        labels = torch.zeros((len(batch), self.out_dim), device=self.device)
        pr, _ = self._get_prediction_and_labels(data, labels, model_output)
        return pr

    def _execute(
        self,
        batch: XYData,
        batch_idx: int,
        metrics: Optional[torch.nn.Module] = None,
        prefix: Optional[str] = "",
        log: Optional[bool] = True,
        sync_dist: Optional[bool] = False,
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Executes the model on a batch of data and returns the model output and predictions.

        If ``log`` is True and a criterion is set, the loss is computed and logged as ``f"{prefix}loss"``
        (extra entries of a tuple-valued loss are logged by ``_log_loss_components``). If ``log`` is True,
        ``metrics`` is non-empty and labels are present, every metric is updated and logged.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.
            metrics (torch.nn.Module, optional): A dict-like module mapping names to metrics to track.
            prefix (str, optional): A prefix to add to the logged names. Defaults to "".
            log (bool, optional): Whether to compute/log the loss and metrics. Defaults to True.
            sync_dist (bool, optional): Passed as ``sync_dist`` when logging loss values. Defaults to False.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model
            output, predictions, and loss (if computed).
        """
        assert isinstance(batch, XYData)
        batch = batch.to(self.device)
        batch_size = len(batch)
        data = self._process_batch(batch, batch_idx)
        labels = data["labels"]
        model_output = self(data, **data.get("model_kwargs", dict()))
        pr, tar = self._get_prediction_and_labels(data, labels, model_output)
        d = dict(data=data, labels=labels, output=model_output, preds=pr)
        if log:
            if self.criterion is not None:
                loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
                    model_output, labels, data.get("loss_kwargs", dict())
                )
                loss_kwargs = loss_kwargs_candidates if self.pass_loss_kwargs else dict()
                loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
                if isinstance(loss, tuple):
                    loss = self._log_loss_components(loss, prefix, batch_size, sync_dist)
                d["loss"] = loss
                self.log(
                    f"{prefix}loss",
                    loss.item(),
                    batch_size=batch_size,
                    on_step=True,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True,
                    sync_dist=sync_dist,
                )
            if metrics and labels is not None:
                for metric in metrics.values():
                    metric.update(pr, tar)
                self._log_metrics(prefix, metrics, batch_size)
        if "loss" not in d:
            print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
        return d

    def _log_loss_components(
        self, loss: tuple, prefix: str, batch_size: int, sync_dist: bool
    ) -> torch.Tensor:
        """
        Logs the auxiliary entries of a tuple returned by the criterion and returns its main loss.

        Layout handled: ``(main_loss, [named: dict,] *unnamed)``. Entries of an optional dict at index 1
        are logged under their own keys (no prefix); the remaining entries are logged as
        ``f"{prefix}loss_{i}"`` with ``i`` counting from 0.

        Args:
            loss (tuple): The tuple returned by the criterion.
            prefix (str): Prefix for the unnamed component names.
            batch_size (int): The batch size passed to ``self.log``.
            sync_dist (bool): Passed as ``sync_dist`` to ``self.log``.

        Returns:
            torch.Tensor: ``loss[0]``, the main loss.
        """
        log_kwargs = dict(
            batch_size=batch_size,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=sync_dist,
        )
        unnamed_start = 1
        # Guard the length so a 1-tuple does not raise IndexError on loss[1].
        if len(loss) > 1 and isinstance(loss[1], dict):
            unnamed_start = 2
            for key, value in loss[1].items():
                self.log(key, self._loss_log_value(value), **log_kwargs)
        for i, component in enumerate(loss[unnamed_start:]):
            self.log(f"{prefix}loss_{i}", self._loss_log_value(component), **log_kwargs)
        return loss[0]

    @staticmethod
    def _loss_log_value(value: Any) -> Any:
        """
        Converts a loss component into a value for ``self.log``.

        Python ints and floats are returned unchanged (a plain float has no ``.item()``); anything else,
        e.g. a tensor, is converted with ``.item()``.
        """
        return value if isinstance(value, (int, float)) else value.item()

    def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
        """
        Logs the metrics for the given prefix.

        Each metric object is passed to ``self.log`` as-is (not computed here) under the name
        ``f"{prefix}{metric_name}"``, with ``on_step=False`` and ``on_epoch=True``.

        Args:
            prefix (str): The prefix to be added to the metric names.
            metrics (torch.nn.Module): A dict-like module mapping names to metrics.
            batch_size (int): The batch size used for logging.

        Returns:
            None
        """
        # don't use sync_dist=True if the metric is a torchmetrics-metric
        # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757)
        for metric_name, metric in metrics.items():
            self.log(
                f"{prefix}{metric_name}",
                metric,
                batch_size=batch_size,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                logger=True,
            )

    @abstractmethod
    def forward(self, x: Dict[str, Any]) -> torch.Tensor:
        """
        Defines the forward pass.

        Args:
            x (Dict[str, Any]): The input data (the processed batch from ``_process_batch``). Callers in
                this class also pass ``data["model_kwargs"]`` as keyword arguments.

        Returns:
            torch.Tensor: The model output.
        """
        pass

    def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
        """
        Configures the optimizers.

        Args:
            **kwargs: Additional keyword arguments (unused).

        Returns:
            torch.optim.Optimizer: An Adamax optimizer over all parameters, built with ``self.optimizer_kwargs``.
        """
        return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs)