Commit 3b8bbc7

added graceful OOM handling (#354)
* added graceful OOM handling
* fix for pytorch <1.13
1 parent 1b2abf5 commit 3b8bbc7

4 files changed: 148 additions & 55 deletions

src/pytorch_tabular/tabular_model.py

Lines changed: 47 additions & 9 deletions
@@ -50,7 +50,13 @@
     PreEncoded1dLayer,
 )
 from pytorch_tabular.tabular_datamodule import TabularDatamodule
-from pytorch_tabular.utils import get_logger, getattr_nested, pl_load
+from pytorch_tabular.utils import (
+    OOMException,
+    OutOfMemoryHandler,
+    get_logger,
+    getattr_nested,
+    pl_load,
+)

 try:
     import captum.attr
@@ -574,6 +580,7 @@ def train(
         callbacks: Optional[List[pl.Callback]] = None,
         max_epochs: int = None,
         min_epochs: int = None,
+        handle_oom: bool = True,
     ) -> pl.Trainer:
         """Trains the model.
@@ -589,6 +596,8 @@ def train(
             min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.

+            handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.
+
         Returns:
             pl.Trainer: The PyTorch Lightning Trainer instance
         """
@@ -601,18 +610,36 @@ def train(
         if self.config.auto_lr_find and (not self.config.fast_dev_run):
             if self.verbose:
                 logger.info("Auto LR Find Started")
-            result = Tuner(self.trainer).lr_find(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+            with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
+                result = Tuner(self.trainer).lr_find(
+                    self.model,
+                    train_dataloaders=train_loader,
+                    val_dataloaders=val_loader,
+                )
+            if oom_handler.oom_triggered:
+                raise OOMException(
+                    "OOM detected during LR Find. Try reducing your batch_size or the"
+                    " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
+                )
             if self.verbose:
                 logger.info(
                     f"Suggested LR: {result.suggestion()}. For plot and detailed"
                     " analysis, use `find_learning_rate` method."
                 )
+            self.model.reset_weights()
             # Parameters in models need to be initialized again after LR find
             self.model.data_aware_initialization(self.datamodule)
         self.model.train()
         if self.verbose:
             logger.info("Training Started")
-        self.trainer.fit(self.model, train_loader, val_loader)
+        with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
+            self.trainer.fit(self.model, train_loader, val_loader)
+        if oom_handler.oom_triggered:
+            raise OOMException(
+                "OOM detected during Training. Try reducing your batch_size or the"
+                " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
+            )
         self._is_fitted = True
         if self.verbose:
             logger.info("Training the model completed")
@@ -637,6 +664,7 @@ def fit(
         callbacks: Optional[List[pl.Callback]] = None,
         datamodule: Optional[TabularDatamodule] = None,
         cache_data: str = "memory",
+        handle_oom: bool = True,
     ) -> pl.Trainer:
         """The fit method which takes in the data and triggers the training.
@@ -690,6 +718,8 @@ def fit(
             cache_data (str): Decides how to cache the data in the dataloader. If set to
                 "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".

+            handle_oom (bool): If True, will try to handle OOM errors elegantly. Defaults to True.
+
         Returns:
             pl.Trainer: The PyTorch Lightning Trainer instance
         """
@@ -728,7 +758,7 @@ def fit(
             optimizer_params or {},
         )

-        return self.train(model, datamodule, callbacks, max_epochs, min_epochs)
+        return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)

     def pretrain(
         self,
@@ -1229,7 +1259,7 @@ def predict(

             progress_bar = partial(tqdm, description="Generating Predictions...")
         else:
-            progress_bar = lambda it: it
+            progress_bar = lambda it: it  # noqa E731
         for batch in progress_bar(inference_dataloader):
             for k, v in batch.items():
                 if isinstance(v, list) and (len(v) == 0):
@@ -1293,8 +1323,9 @@ def predict(
                 np.argmax(point_predictions, axis=1)
             )
             warnings.warn(
-                "Classification prediction column will be renamed to `{target_col}_prediction` "
-                "in the next release to maintain consistency with regression.",
+                "Classification prediction column will be renamed to"
+                " `{target_col}_prediction` in the next release to maintain"
+                " consistency with regression.",
                 DeprecationWarning,
             )
         if ret_logits:
@@ -1710,6 +1741,7 @@ def cross_validate(
         groups: Optional[Union[str, np.ndarray]] = None,
         verbose: bool = True,
         reset_datamodule: bool = True,
+        handle_oom: bool = True,
         **kwargs,
     ):
         """Cross validate the model.
@@ -1753,6 +1785,7 @@ def cross_validate(
                 If False, we take an approximation that once the transformations are fit on the first
                 fold, they will be valid for all the other folds. Defaults to True.

+            handle_oom (bool, optional): If True, will handle out of memory errors elegantly
             **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

         Returns:
@@ -1789,7 +1822,8 @@ def cross_validate(
             datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")

             # Train the model
-            self.train(model, datamodule, **train_kwargs)
+            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
+            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
             if return_oof or is_callable_metric:
                 preds = self.predict(val_fold, include_input_features=False)
                 oof_preds.append(preds)
@@ -1864,6 +1898,7 @@ def bagging_predict(
         return_raw_predictions: bool = False,
         aggregate: Union[str, Callable] = "mean",
         weights: Optional[List[float]] = None,
+        handle_oom: bool = True,
         **kwargs,
     ):
         """Bagging predict on the test data.
@@ -1912,6 +1947,8 @@ def bagging_predict(
                 from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
                 Defaults to None.

+            handle_oom (bool, optional): If True, will handle out of memory errors elegantly
+
             **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.

         Returns:
@@ -1953,7 +1990,8 @@ def bagging_predict(
             datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")

             # Train the model
-            self.train(model, datamodule, **train_kwargs)
+            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
+            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
             fold_preds = self.predict(test, include_input_features=False)
             pred_idx = fold_preds.index
             if self.config.task == "classification":
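
With this change, both LR Find and the actual training run inside an OutOfMemoryHandler context, and a CUDA OOM surfaces as a single OOMException instead of a raw crash. A minimal usage sketch of the new `handle_oom` flag (the config objects and dataframes are placeholders assumed to be defined elsewhere, not part of this commit):

    from pytorch_tabular import TabularModel
    from pytorch_tabular.utils import OOMException

    tabular_model = TabularModel(
        data_config=data_config,        # placeholder configs, defined elsewhere
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )
    try:
        # handle_oom=True is the default; it converts CUDA OOM failures
        # raised inside fit/train into an OOMException
        tabular_model.fit(train=train_df, validation=val_df, handle_oom=True)
    except OOMException:
        # e.g. halve the batch size and retry instead of losing the process
        retry_with_smaller_batch()      # hypothetical recovery hook

The same flag is forwarded by `cross_validate` and `bagging_predict` via `train_kwargs.pop("handle_oom", handle_oom)`, so a per-call override passed through `**kwargs` wins over the method argument.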

src/pytorch_tabular/tabular_model_tuner.py

Lines changed: 58 additions & 42 deletions
@@ -4,7 +4,6 @@
 """Tabular Model."""
 import warnings
 from collections import namedtuple
-from contextlib import nullcontext
 from copy import deepcopy
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Optional, Union
@@ -13,7 +12,7 @@
 import pandas as pd
 from omegaconf.dictconfig import DictConfig
 from pandas import DataFrame
-from rich.progress import Progress
+from rich.progress import track
 from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler

 from pytorch_tabular.config import (
@@ -23,7 +22,7 @@
     TrainerConfig,
 )
 from pytorch_tabular.tabular_model import TabularModel
-from pytorch_tabular.utils import get_logger
+from pytorch_tabular.utils import OOMException, OutOfMemoryHandler, get_logger

 logger = get_logger(__name__)
@@ -146,6 +145,7 @@ def tune(
         verbose: bool = False,
         progress_bar: bool = True,
         random_state: Optional[int] = 42,
+        ignore_oom: bool = True,
         **kwargs,
     ):
         """Tune the hyperparameters of the TabularModel.
@@ -194,6 +194,8 @@ def tune(
             random_state (Optional[int], optional): Random state to be used for random search. Defaults to 42.

+            ignore_oom (bool, optional): Whether to ignore out of memory errors. Defaults to True.
+
             **kwargs: Additional keyword arguments to be passed to the TabularModel fit.

         Returns:
@@ -230,9 +232,7 @@ def tune(
         else:
             raise NotImplementedError(f"{strategy} is not implemented yet.")
         if progress_bar:
-            ctx_mgr = Progress()
-        else:
-            ctx_mgr = nullcontext()
+            iterator = track(iterator, description=f"[green]{strategy.replace('_',' ').title()}...", total=n_trials)
         verbose_tabular_model = self.tabular_model_init_kwargs.pop("verbose", False)
         temp_tabular_model = TabularModel(
             data_config=self.data_config,
@@ -253,58 +253,74 @@ def tune(
             is_callable_metric = True
         del temp_tabular_model
         trials = []
+        for i, params in enumerate(iterator):
+            # Copying the configs as a base
+            # Make sure all default parameters that you want to be set for all
+            # trials are in the original configs
+            trainer_config_t = deepcopy(self.trainer_config)
+            optimizer_config_t = deepcopy(self.optimizer_config)
+            model_config_t = deepcopy(self.model_config)

-        with ctx_mgr as progress:
-            if progress:
-                task = progress.add_task(f"[green]{strategy.replace('_',' ').title()}...", total=n_trials)
-            for i, params in enumerate(iterator):
-                # Copying the configs as a base
-                # Make sure all default parameters that you want to be set for all
-                # trials are in the original configs
-                trainer_config_t = deepcopy(self.trainer_config)
-                optimizer_config_t = deepcopy(self.optimizer_config)
-                model_config_t = deepcopy(self.model_config)
-
-                trainer_config_t, optimizer_config_t, model_config_t = self._update_configs(
-                    trainer_config_t, optimizer_config_t, model_config_t, params
-                )
-                # Initialize Tabular model using the new config
-                tabular_model_t = TabularModel(
-                    data_config=self.data_config,
-                    model_config=model_config_t,
-                    optimizer_config=optimizer_config_t,
-                    trainer_config=trainer_config_t,
-                    verbose=verbose_tabular_model,
-                    **self.tabular_model_init_kwargs,
-                )
-                if cv is not None:
-                    cv_verbose = cv_kwargs.pop("verbose", False)
+            trainer_config_t, optimizer_config_t, model_config_t = self._update_configs(
+                trainer_config_t, optimizer_config_t, model_config_t, params
+            )
+            # Initialize Tabular model using the new config
+            tabular_model_t = TabularModel(
+                data_config=self.data_config,
+                model_config=model_config_t,
+                optimizer_config=optimizer_config_t,
+                trainer_config=trainer_config_t,
+                verbose=verbose_tabular_model,
+                **self.tabular_model_init_kwargs,
+            )
+            if cv is not None:
+                cv_verbose = cv_kwargs.pop("verbose", False)
+                cv_kwargs.pop("handle_oom", None)
+                with OutOfMemoryHandler(handle_oom=True) as handler:
                     cv_scores, _ = tabular_model_t.cross_validate(
                         cv=cv,
                         train=train,
                         metric=metric,
                         verbose=cv_verbose,
+                        handle_oom=False,
                         **cv_kwargs,
                     )
+                if handler.oom_triggered:
+                    if not ignore_oom:
+                        raise OOMException(
+                            "Out of memory error occurred during cross validation. "
+                            "Set ignore_oom=True to ignore this error."
+                        )
+                    else:
+                        params.update({metric.__name__ if is_callable_metric else metric: "OOM"})
+                else:
                     params.update({metric.__name__ if is_callable_metric else metric: cv_agg_func(cv_scores)})
+            else:
+                model = tabular_model_t.prepare_model(
+                    datamodule=datamodule,
+                    **prep_model_kwargs,
+                )
+                train_kwargs.pop("handle_oom", None)
+                with OutOfMemoryHandler(handle_oom=True) as handler:
+                    tabular_model_t.train(model=model, datamodule=datamodule, handle_oom=False, **train_kwargs)
+                if handler.oom_triggered:
+                    if not ignore_oom:
+                        raise OOMException(
+                            "Out of memory error occurred during training. Set ignore_oom=True to ignore this error."
+                        )
+                    else:
+                        params.update({metric.__name__ if is_callable_metric else metric: "OOM"})
                 else:
-                    model = tabular_model_t.prepare_model(
-                        datamodule=datamodule,
-                        **prep_model_kwargs,
-                    )
-                    tabular_model_t.train(model=model, datamodule=datamodule, **train_kwargs)
                     if is_callable_metric:
                         preds = tabular_model_t.predict(validation, include_input_features=False)
                         params.update({metric.__name__: metric(validation[tabular_model_t.config.target], preds)})
                     else:
                         result = tabular_model_t.evaluate(validation, verbose=False)
                         params.update({k.replace("test_", ""): v for k, v in result[0].items()})
-                params.update({"trial_id": i})
-                trials.append(params)
-                if verbose:
-                    logger.info(f"Trial {i+1}/{n_trials}: {params} | Score: {params[metric]}")
-                if progress:
-                    progress.update(task, advance=1)
+            params.update({"trial_id": i})
+            trials.append(params)
+            if verbose:
+                logger.info(f"Trial {i+1}/{n_trials}: {params} | Score: {params[metric]}")
         trials_df = pd.DataFrame(trials)
         trials = trials_df.pop("trial_id")
         if mode == "max":
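
For the tuner, the point of `ignore_oom` is that one oversized configuration no longer kills an entire sweep: the failed trial's metric is recorded as the string "OOM" and the loop moves on, while `ignore_oom=False` re-raises as OOMException. A hedged sketch of a call exercising this (the search-space key and configs are illustrative placeholders, and the returned result is assumed to expose the trials dataframe):

    from pytorch_tabular.tabular_model_tuner import TabularModelTuner

    tuner = TabularModelTuner(
        data_config=data_config,        # placeholder configs, defined elsewhere
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )
    result = tuner.tune(
        train=train_df,
        validation=val_df,
        search_space={"trainer_config__batch_size": [256, 65536]},  # second value may OOM
        strategy="grid_search",
        metric="accuracy",
        mode="max",
        ignore_oom=True,                # OOM trials are scored as "OOM", not raised
    )
    # trials that ran out of memory show the literal string "OOM" in the metric column
    print(result.trials_df)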

src/pytorch_tabular/utils/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@
 from .data_utils import get_balanced_sampler, get_class_weighted_cross_entropy, get_gaussian_centers
 from .logger import get_logger
 from .nn_utils import (
+    OOMException,
+    OutOfMemoryHandler,
     _initialize_kaiming,
     _initialize_layers,
     _linear_dropout_bn,
@@ -26,4 +28,6 @@
     "to_one_hot",
     "_initialize_kaiming",
     "check_numpy",
+    "OutOfMemoryHandler",
+    "OOMException",
 ]
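
`OOMException` and `OutOfMemoryHandler` themselves are defined in `nn_utils` (the fourth changed file; its hunk is not shown in this excerpt). As a rough illustration of the pattern the callers above rely on, and of what the "fix for pytorch <1.13" bullet likely addresses (`torch.cuda.OutOfMemoryError` only exists from PyTorch 1.13, so older versions have to match the RuntimeError message), a sketch of such a context manager might look like this; it is an assumption, not the committed implementation:

    import torch

    class OOMException(Exception):
        """Raised by callers when a wrapped step ran out of memory."""

    class OutOfMemoryHandler:
        def __init__(self, handle_oom: bool = True):
            self.handle_oom = handle_oom
            self.oom_triggered = False
            self.oom_msg = None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            if exc_type is None or not self.handle_oom:
                return False  # no error, or propagate it untouched
            # torch.cuda.OutOfMemoryError was added in PyTorch 1.13;
            # on older versions a CUDA OOM surfaces as a plain RuntimeError
            oom_error = getattr(torch.cuda, "OutOfMemoryError", None)
            is_oom = (oom_error is not None and issubclass(exc_type, oom_error)) or (
                issubclass(exc_type, RuntimeError) and "out of memory" in str(exc_value)
            )
            if not is_oom:
                return False
            self.oom_triggered = True
            self.oom_msg = str(exc_value)
            torch.cuda.empty_cache()  # release cached blocks before the caller retries
            return True  # swallow the error; callers check oom_triggered instead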
