@@ -1201,65 +1201,19 @@ def evaluate(
     )
     return result

-    def predict(
+    def _generate_predictions(
         self,
-        test: DataFrame,
-        quantiles: Optional[List] = [0.25, 0.5, 0.75],
-        n_samples: Optional[int] = 100,
-        ret_logits=False,
-        include_input_features: bool = False,
-        device: Optional[torch.device] = None,
-        progress_bar: Optional[str] = None,
-    ) -> DataFrame:
-        """Uses the trained model to predict on new data and return as a dataframe.
-
-        Args:
-            test (DataFrame): The new dataframe with the features defined during training
-            quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
-                the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
-                For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
-            n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
-                Ignored for non-probabilistic models. Defaults to 100
-            ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
-                with the dataframe. Defaults to False
-            include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
-                Defaults to True
-            progress_bar: chose progress bar for tracking the progress
-
-        Returns:
-            DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
-                If classification, it returns probabilities and final prediction
-        """
-        warnings.warn(
-            "`include_input_features` will be deprecated in the next release."
-            " Please add index columns to the test dataframe if you want to"
-            " retain some features like the key or id",
-            DeprecationWarning,
-        )
-        assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
-        model = self.model  # default
-        if device is not None:
-            if isinstance(device, str):
-                device = torch.device(device)
-            if self.model.device != device:
-                model = self.model.to(device)
-        model.eval()
-        inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
+        model,
+        inference_dataloader,
+        quantiles,
+        n_samples,
+        ret_logits,
+        progress_bar,
+        is_probabilistic,
+    ):
         point_predictions = []
         quantile_predictions = []
         logits_predictions = defaultdict(list)
-        is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic
-
-        if progress_bar == "rich":
-            from rich.progress import track
-
-            progress_bar = partial(track, description="Generating Predictions...")
-        elif progress_bar == "tqdm":
-            from tqdm.auto import tqdm
-
-            progress_bar = partial(tqdm, description="Generating Predictions...")
-        else:
-            progress_bar = lambda it: it  # noqa E731
         for batch in progress_bar(inference_dataloader):
             for k, v in batch.items():
                 if isinstance(v, list) and (len(v) == 0):
@@ -1275,8 +1229,6 @@ def predict(
                 y_hat, ret_value = model.predict(batch, ret_model_output=True)
             if ret_logits:
                 for k, v in ret_value.items():
-                    # if k == "backbone_features":
-                    #     continue
                     logits_predictions[k].append(v.detach().cpu())
             point_predictions.append(y_hat.detach().cpu())
             if is_probabilistic:
@@ -1288,6 +1240,19 @@ def predict(
             quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1)
             if quantile_predictions.ndim == 2:
                 quantile_predictions = quantile_predictions.unsqueeze(-1)
+        return point_predictions, quantile_predictions, logits_predictions
+
+    def _format_predicitons(
+        self,
+        test,
+        point_predictions,
+        quantile_predictions,
+        logits_predictions,
+        quantiles,
+        ret_logits,
+        include_input_features,
+        is_probabilistic,
+    ):
         pred_df = test.copy() if include_input_features else DataFrame(index=test.index)
         if self.config.task == "regression":
             point_predictions = point_predictions.numpy()
@@ -1340,6 +1305,188 @@ def predict(
                         pred_df[f"{k}"] = v[:, i]
         return pred_df

+    def _predict(
+        self,
+        test: DataFrame,
+        quantiles: Optional[List] = [0.25, 0.5, 0.75],
+        n_samples: Optional[int] = 100,
+        ret_logits=False,
+        include_input_features: bool = False,
+        device: Optional[torch.device] = None,
+        progress_bar: Optional[str] = None,
+    ) -> DataFrame:
+        """Uses the trained model to predict on new data and return as a dataframe.
+
+        Args:
+            test (DataFrame): The new dataframe with the features defined during training
+            quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
+                the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
+                For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
+            n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
+                Ignored for non-probabilistic models. Defaults to 100
+            ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
+                with the dataframe. Defaults to False
+            include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
+                Defaults to False
+            progress_bar: choose the progress bar for tracking progress
+
+        Returns:
+            DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
+                If classification, it returns probabilities and the final prediction
+        """
+        assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
+        model = self.model  # default
+        if device is not None:
+            if isinstance(device, str):
+                device = torch.device(device)
+            if self.model.device != device:
+                model = self.model.to(device)
+        model.eval()
+        inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
+        is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic
+
+        if progress_bar == "rich":
+            from rich.progress import track
+
+            progress_bar = partial(track, description="Generating Predictions...")
+        elif progress_bar == "tqdm":
+            from tqdm.auto import tqdm
+
+            progress_bar = partial(tqdm, desc="Generating Predictions...")
+        else:
+            progress_bar = lambda it: it  # noqa E731
+        point_predictions, quantile_predictions, logits_predictions = self._generate_predictions(
+            model,
+            inference_dataloader,
+            quantiles,
+            n_samples,
+            ret_logits,
+            progress_bar,
+            is_probabilistic,
+        )
+        pred_df = self._format_predicitons(
+            test,
+            point_predictions,
+            quantile_predictions,
+            logits_predictions,
+            quantiles,
+            ret_logits,
+            include_input_features,
+            is_probabilistic,
+        )
+        return pred_df
+
+    def predict(
+        self,
+        test: DataFrame,
+        quantiles: Optional[List] = [0.25, 0.5, 0.75],
+        n_samples: Optional[int] = 100,
+        ret_logits=False,
+        include_input_features: bool = False,
+        device: Optional[torch.device] = None,
+        progress_bar: Optional[str] = None,
+        test_time_augmentation: Optional[bool] = False,
+        num_tta: Optional[float] = 5,
+        alpha_tta: Optional[float] = 0.1,
+        aggregate_tta: Optional[str] = "mean",
+    ) -> DataFrame:
+        """Uses the trained model to predict on new data and return as a dataframe.
+
+        Args:
+            test (DataFrame): The new dataframe with the features defined during training
+
+            quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
+                the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
+                For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
+
+            n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
+                Ignored for non-probabilistic models. Defaults to 100
+
+            ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
+                with the dataframe. Defaults to False
+
+            include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
+                Defaults to False
+
+            progress_bar: choose the progress bar for tracking progress
+
+            test_time_augmentation (bool): If True, will use test time augmentation to generate predictions.
+                The approach is very similar to what is described [here](https://kozodoi.me/blog/20210908/tta-tabular),
+                but we add noise to the embedded inputs to handle categorical features as well:
+                \\(x_{aug} = x_{orig} + \\alpha * \\epsilon\\) where \\(\\epsilon \\sim \\mathcal{N}(0, 1)\\).
+                Defaults to False
+            num_tta (float): The number of augmentations to run TTA for. Defaults to 5
+
+            alpha_tta (float): The standard deviation of the Gaussian noise to be added to the input features. Defaults to 0.1
+
+            aggregate_tta (Union[str, Callable], optional): The function to be used to aggregate the
+                predictions from each augmentation. If str, should be one of "mean", "median", "min", or "max"
+                for regression. For classification, the previous options are applied to the confidence
+                scores (soft voting) and then converted to the final prediction. An additional option
+                "hard_voting" is available for classification.
+                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
+                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
+
+
+        Returns:
+            DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
+                If classification, it returns probabilities and the final prediction
+        """
+        warnings.warn(
+            "`include_input_features` will be deprecated in the next release."
+            " Please add index columns to the test dataframe if you want to"
+            " retain some features like the key or id",
+            DeprecationWarning,
+        )
+        if test_time_augmentation:
+            assert num_tta > 0, "num_tta should be greater than 0"
+            assert alpha_tta > 0, "alpha_tta should be greater than 0"
+            assert include_input_features is False, "include_input_features cannot be True for TTA."
+            if not callable(aggregate_tta):
+                assert aggregate_tta in ["mean", "median", "min", "max", "hard_voting"], (
+                    "aggregate should be one of 'mean', 'median', 'min', 'max', or 'hard_voting'"
+                )
+            if self.config.task == "regression":
+                assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"
+
+            def add_noise(module, input, output):
+                return output + alpha_tta * torch.randn_like(output)
+
+            # Register the hook to the embedding_layer
+            handle = self.model.embedding_layer.register_forward_hook(add_noise)
+            pred_l = []
+            pred_prob_l = []
+            for _ in range(num_tta):
+                pred_df = self._predict(
+                    test,
+                    quantiles,
+                    n_samples,
+                    ret_logits,
+                    include_input_features=False,
+                    device=device,
+                    progress_bar=progress_bar,
+                )
+                pred_idx = pred_df.index
+                if self.config.task == "classification":
+                    pred_l.append(pred_df.values[:, -len(self.config.target):].astype(int))
+                    pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
+                elif self.config.task == "regression":
+                    pred_prob_l.append(pred_df.values)
+            pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate_tta, None)
+            # Remove the hook
+            handle.remove()
+        else:
+            pred_df = self._predict(
+                test,
+                quantiles,
+                n_samples,
+                ret_logits,
+                include_input_features,
+                device,
+                progress_bar,
+            )
+        return pred_df
+
     def load_best_model(self) -> None:
         """Loads the best model after training is done."""
         if self.trainer.checkpoint_callback is not None:
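For context, a minimal usage sketch of the test-time-augmentation path added above (a sketch only: `tabular_model` stands in for a trained TabularModel instance and `test_df` for a held-out dataframe, both hypothetical names; the keyword arguments are the ones introduced in this commit):

    # Hypothetical call: average predictions over 5 noise-perturbed passes.
    # Gaussian noise with std-dev alpha_tta is injected after the embedding layer via a forward hook.
    pred_df = tabular_model.predict(
        test_df,
        test_time_augmentation=True,
        num_tta=5,             # number of augmented passes
        alpha_tta=0.1,         # std-dev of the noise added to the embedded inputs
        aggregate_tta="mean",  # soft voting / averaging across passes
    )
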
@@ -1708,7 +1855,8 @@ def _check_cv(self, cv):
                 return StratifiedKFold(cv)
             else:
                 return KFold(cv)
-        elif isinstance(cv, Iterable):
+        elif isinstance(cv, Iterable) and not isinstance(cv, str):
+            # An iterable yielding (train, test) splits as arrays of indices.
             return cv
         elif isinstance(cv, BaseCrossValidator):
             return cv
@@ -1800,11 +1948,17 @@ def cross_validate(
             metric = metric if metric.startswith("test_") else "test_" + metric
         elif callable(metric):
             is_callable_metric = True
+
+        if isinstance(cv, BaseCrossValidator):
+            it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
+        else:
+            # when an iterable is directly passed
+            it = enumerate(cv)
         cv_metrics = []
         datamodule = None
         model = None
         oof_preds = []
-        for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
+        for fold, (train_idx, val_idx) in it:
             if verbose:
                 logger.info(f"Running Fold {fold + 1}/{cv.get_n_splits()}")
             train_fold = train.iloc[train_idx]
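A quick sketch of the new `cross_validate` branch that accepts a pre-computed iterable of splits (assumed names: `tabular_model` is a TabularModel instance and `train_df` the training dataframe; the return value is not unpacked here):

    import numpy as np

    # Hypothetical example: pass explicit (train_idx, val_idx) pairs
    # instead of an int or a scikit-learn BaseCrossValidator.
    idx = np.arange(len(train_df))
    split_at = int(0.8 * len(train_df))
    custom_splits = [
        (idx[:split_at], idx[split_at:]),  # fold 1
        (idx[split_at:], idx[:split_at]),  # fold 2
    ]
    results = tabular_model.cross_validate(cv=custom_splits, train=train_df)
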