     PreEncoded1dLayer,
 )
 from pytorch_tabular.tabular_datamodule import TabularDatamodule
-from pytorch_tabular.utils import get_logger, getattr_nested, pl_load
+from pytorch_tabular.utils import (
+    OOMException,
+    OutOfMemoryHandler,
+    get_logger,
+    getattr_nested,
+    pl_load,
+)
 
 try:
     import captum.attr
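The import block now pulls in `OOMException` and `OutOfMemoryHandler`, which the hunks below use as a context manager that records CUDA out-of-memory failures instead of propagating them. The utility itself is not part of this diff; what follows is a minimal sketch of the contract implied by its call sites (the class names and the `oom_triggered`/`oom_msg` attributes match the imports, but the body is an assumption, not the library's actual implementation):

```python
# Hypothetical sketch of OutOfMemoryHandler's contract, inferred from usage below.
import torch


class OutOfMemoryHandler:
    """Context manager that optionally swallows CUDA OOM errors and records them."""

    def __init__(self, handle_oom: bool = True):
        self.handle_oom = handle_oom
        self.oom_triggered = False
        self.oom_msg = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.handle_oom and exc_type is not None and issubclass(exc_type, RuntimeError):
            if "out of memory" in str(exc_value).lower():
                self.oom_triggered = True
                self.oom_msg = str(exc_value)
                torch.cuda.empty_cache()  # release cached blocks so the caller can recover
                return True  # suppress the exception; callers check oom_triggered instead
        return False


class OOMException(Exception):
    """Raised by callers once oom_triggered is detected."""
```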
@@ -574,6 +580,7 @@ def train(
         callbacks: Optional[List[pl.Callback]] = None,
         max_epochs: int = None,
         min_epochs: int = None,
+        handle_oom: bool = True,
     ) -> pl.Trainer:
         """Trains the model.
 
@@ -589,6 +596,8 @@ def train(
 
             min_epochs (Optional[int]): Overwrite minimum number of epochs to be run. Defaults to None.
 
+            handle_oom (bool): If True, will try to handle OOM errors gracefully. Defaults to True.
+
         Returns:
             pl.Trainer: The PyTorch Lightning Trainer instance
         """
@@ -601,18 +610,36 @@ def train(
         if self.config.auto_lr_find and (not self.config.fast_dev_run):
             if self.verbose:
                 logger.info("Auto LR Find Started")
-            result = Tuner(self.trainer).lr_find(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+            with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
+                result = Tuner(self.trainer).lr_find(
+                    self.model,
+                    train_dataloaders=train_loader,
+                    val_dataloaders=val_loader,
+                )
+            if oom_handler.oom_triggered:
+                raise OOMException(
+                    "OOM detected during LR Find. Try reducing your batch_size or the"
+                    " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
+                )
             if self.verbose:
                 logger.info(
                     f"Suggested LR: {result.suggestion()}. For plot and detailed"
                     " analysis, use `find_learning_rate` method."
                 )
+            self.model.reset_weights()
             # Parameters in models need to be initialized again after LR find
             self.model.data_aware_initialization(self.datamodule)
         self.model.train()
         if self.verbose:
             logger.info("Training Started")
-        self.trainer.fit(self.model, train_loader, val_loader)
+        with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
+            self.trainer.fit(self.model, train_loader, val_loader)
+        if oom_handler.oom_triggered:
+            raise OOMException(
+                "OOM detected during Training. Try reducing your batch_size or the"
+                " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
+            )
         self._is_fitted = True
         if self.verbose:
             logger.info("Training the model completed")
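Both the LR-find call and the fit call are now wrapped the same way: run inside the handler, then raise `OOMException` with the original error appended if `oom_triggered` was set. Since `fit` forwards `handle_oom` to `train` (next hunks), a caller can build a recovery loop around it. A hedged usage sketch; the DataFrame names and the halve-and-retry policy are illustrative, not part of the library:

```python
from pytorch_tabular.utils import OOMException

try:
    tabular_model.fit(train=train_df, validation=val_df, handle_oom=True)
except OOMException:
    # Illustrative recovery policy: halve the batch size and retry once.
    tabular_model.config.batch_size //= 2
    tabular_model.fit(train=train_df, validation=val_df)
```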
@@ -637,6 +664,7 @@ def fit(
         callbacks: Optional[List[pl.Callback]] = None,
         datamodule: Optional[TabularDatamodule] = None,
         cache_data: str = "memory",
+        handle_oom: bool = True,
     ) -> pl.Trainer:
         """The fit method which takes in the data and triggers the training.
 
@@ -690,6 +718,8 @@ def fit(
             cache_data (str): Decides how to cache the data in the dataloader. If set to
                 "memory", will cache in memory. If set to a valid path, will cache in that path. Defaults to "memory".
 
+            handle_oom (bool): If True, will try to handle OOM errors gracefully. Defaults to True.
+
         Returns:
             pl.Trainer: The PyTorch Lightning Trainer instance
         """
@@ -728,7 +758,7 @@ def fit(
             optimizer_params or {},
         )
 
-        return self.train(model, datamodule, callbacks, max_epochs, min_epochs)
+        return self.train(model, datamodule, callbacks, max_epochs, min_epochs, handle_oom)
 
     def pretrain(
         self,
@@ -1229,7 +1259,7 @@ def predict(
 
             progress_bar = partial(tqdm, description="Generating Predictions...")
         else:
-            progress_bar = lambda it: it
+            progress_bar = lambda it: it  # noqa: E731
         for batch in progress_bar(inference_dataloader):
             for k, v in batch.items():
                 if isinstance(v, list) and (len(v) == 0):
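The `# noqa: E731` comment silences flake8's "do not assign a lambda expression" rule rather than restructuring the code. The lint-clean alternative would be a small local function, e.g.:

```python
def progress_bar(it):
    # Identity pass-through when verbose output is disabled
    return it
```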
@@ -1293,8 +1323,9 @@ def predict(
                     np.argmax(point_predictions, axis=1)
                 )
                 warnings.warn(
-                    "Classification prediction column will be renamed to `{target_col}_prediction` "
-                    "in the next release to maintain consistency with regression.",
+                    "Classification prediction column will be renamed to"
+                    " `{target_col}_prediction` in the next release to maintain"
+                    " consistency with regression.",
                     DeprecationWarning,
                 )
                 if ret_logits:
@@ -1710,6 +1741,7 @@ def cross_validate(
         groups: Optional[Union[str, np.ndarray]] = None,
         verbose: bool = True,
         reset_datamodule: bool = True,
+        handle_oom: bool = True,
         **kwargs,
     ):
         """Cross validate the model.
@@ -1753,6 +1785,7 @@ def cross_validate(
                 If False, we take an approximation that once the transformations are fit on the first
                 fold, they will be valid for all the other folds. Defaults to True.
 
+            handle_oom (bool, optional): If True, will handle out of memory errors gracefully. Defaults to True.
             **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.
 
         Returns:
@@ -1789,7 +1822,8 @@ def cross_validate(
                 datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")
 
             # Train the model
-            self.train(model, datamodule, **train_kwargs)
+            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
+            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
             if return_oof or is_callable_metric:
                 preds = self.predict(val_fold, include_input_features=False)
                 oof_preds.append(preds)
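`train_kwargs.pop("handle_oom", handle_oom)` gives a `handle_oom` passed through `**kwargs` precedence over the explicit parameter, and removes it from `train_kwargs` so `self.train` does not receive the keyword twice. A small self-contained illustration of that pattern:

```python
train_kwargs = {"max_epochs": 5, "handle_oom": False}
handle_oom = train_kwargs.pop("handle_oom", True)  # the kwarg wins over the default
assert handle_oom is False
assert "handle_oom" not in train_kwargs  # no duplicate keyword when forwarded
```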
@@ -1864,6 +1898,7 @@ def bagging_predict(
         return_raw_predictions: bool = False,
         aggregate: Union[str, Callable] = "mean",
         weights: Optional[List[float]] = None,
+        handle_oom: bool = True,
         **kwargs,
     ):
         """Bagging predict on the test data.
@@ -1912,6 +1947,8 @@ def bagging_predict(
                 from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
                 Defaults to None.
 
+            handle_oom (bool, optional): If True, will handle out of memory errors gracefully. Defaults to True.
+
             **kwargs: Additional keyword arguments to be passed to the `fit` method of the model.
 
         Returns:
@@ -1953,7 +1990,8 @@ def bagging_predict(
                 datamodule.validation, _ = datamodule.preprocess_data(val_fold, stage="inference")
 
             # Train the model
-            self.train(model, datamodule, **train_kwargs)
+            handle_oom = train_kwargs.pop("handle_oom", handle_oom)
+            self.train(model, datamodule, handle_oom=handle_oom, **train_kwargs)
             fold_preds = self.predict(test, include_input_features=False)
             pred_idx = fold_preds.index
             if self.config.task == "classification":
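`bagging_predict` gains the same plumbing as `cross_validate`: the flag is popped from `train_kwargs` before each fold's training run. A usage sketch; the `cv`, `train`, and `test` argument names and DataFrames are assumptions, as they are not visible in this diff:

```python
pred_df = tabular_model.bagging_predict(
    cv=5,
    train=train_df,
    test=test_df,
    aggregate="mean",
    handle_oom=False,  # fail fast with the raw CUDA error while debugging
)
```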