@@ -1424,9 +1424,8 @@ def predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
-
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
         Returns:
             DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
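Judging from how `pred_prob_l` is built in the hunks below, the callable receives one probability array per TTA run (or CV fold) and must reduce them to a single `(num_samples, num_targets)` array. A minimal sketch of such a custom aggregate; the function name and the use of `scipy.stats.trim_mean` are illustrative, not part of the library:

```python
import numpy as np
from scipy import stats

# Hypothetical custom aggregate: a trimmed mean over the per-run probability
# arrays, discarding the most extreme 10% of values on each side per cell.
def trimmed_mean_aggregate(pred_prob_l):
    stacked = np.stack(pred_prob_l, axis=0)       # (num_runs, num_samples, num_targets)
    return stats.trim_mean(stacked, 0.1, axis=0)  # (num_samples, num_targets)

# Passed as `aggregate_tta=trimmed_mean_aggregate` to `predict`
# (or `aggregate=trimmed_mean_aggregate` to `bagging_predict`).
```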
@@ -1454,7 +1453,6 @@ def add_noise(module, input, output):
 
             # Register the hook to the embedding_layer
             handle = self.model.embedding_layer.register_forward_hook(add_noise)
-            pred_l = []
             pred_prob_l = []
             for _ in range(num_tta):
                 pred_df = self._predict(
@@ -1468,11 +1466,10 @@ def add_noise(module, input, output):
                 )
                 pred_idx = pred_df.index
                 if self.config.task == "classification":
-                    pred_l.append(pred_df.values[:, -len(self.config.target) :].astype(int))
                     pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
                 elif self.config.task == "regression":
                     pred_prob_l.append(pred_df.values)
-            pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate_tta, None)
+            pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
             # Remove the hook
             handle.remove()
         else:
@@ -1993,7 +1990,6 @@ def cross_validate(
 
     def _combine_predictions(
         self,
-        pred_l: List[DataFrame],
         pred_prob_l: List[DataFrame],
         pred_idx: Union[pd.Index, List],
         aggregate: Union[str, Callable],
@@ -2008,15 +2004,16 @@ def _combine_predictions(
         elif aggregate == "max":
             bagged_pred = np.max(pred_prob_l, axis=0)
         elif aggregate == "hard_voting" and self.config.task == "classification":
+            pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]
             final_pred = np.apply_along_axis(
                 lambda x: np.argmax(np.bincount(x)),
                 axis=0,
-                arr=[p[:, -1].astype(int) for p in pred_l],
+                arr=pred_l,
             )
         elif callable(aggregate):
-            final_pred = bagged_pred = aggregate(pred_prob_l)
+            bagged_pred = aggregate(pred_prob_l)
         if self.config.task == "classification":
-            if aggregate == "hard_voting" or callable(aggregate):
+            if aggregate == "hard_voting":
                 pred_df = pd.DataFrame(
                     np.concatenate(pred_prob_l, axis=1),
                     columns=[
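To make the new hard-voting path concrete, a toy numpy sketch with made-up probabilities for three runs and two samples, mirroring the `argmax` + `bincount` logic added above:

```python
import numpy as np

# Toy data: class probabilities from three runs for two samples, two classes.
pred_prob_l = [
    np.array([[0.7, 0.3], [0.4, 0.6]]),  # run 1
    np.array([[0.6, 0.4], [0.2, 0.8]]),  # run 2
    np.array([[0.2, 0.8], [0.3, 0.7]]),  # run 3
]
pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]  # per-run class labels
final_pred = np.apply_along_axis(
    lambda x: np.argmax(np.bincount(x)),
    axis=0,
    arr=pred_l,
)
print(final_pred)  # [0 1] -> sample 1: class 0 wins 2 of 3 votes; sample 2: class 1 is unanimous
```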
@@ -2094,8 +2091,8 @@ def bagging_predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
             weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
                 from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
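For intuition on the `weights` option, a small hedged sketch of weighted soft voting over per-fold probabilities; whether the library uses `np.average` internally is an assumption here:

```python
import numpy as np

# Made-up per-fold probabilities for two samples and two classes.
pred_prob_l = [
    np.array([[0.9, 0.1], [0.4, 0.6]]),
    np.array([[0.5, 0.5], [0.3, 0.7]]),
    np.array([[0.2, 0.8], [0.6, 0.4]]),
]
weights = [0.5, 0.3, 0.2]  # e.g. proportional to each fold's validation score
bagged_pred = np.average(np.stack(pred_prob_l, axis=0), axis=0, weights=weights)
print(bagged_pred)  # weighted per-class confidences, shape (2, 2)
```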
@@ -2122,7 +2119,6 @@ def bagging_predict(
             assert aggregate != "hard_voting", "hard_voting is only available for classification"
         cv = self._check_cv(cv)
         prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
-        pred_l = []
         pred_prob_l = []
         datamodule = None
         model = None
@@ -2149,15 +2145,14 @@ def bagging_predict(
             fold_preds = self.predict(test, include_input_features=False)
             pred_idx = fold_preds.index
             if self.config.task == "classification":
-                pred_l.append(fold_preds.values[:, -len(self.config.target) :].astype(int))
                 pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
             elif self.config.task == "regression":
                 pred_prob_l.append(fold_preds.values)
             if verbose:
                 logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
             self.model.reset_weights()
-        pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate, weights)
+        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
         if return_raw_predictions:
-            return pred_df, pred_l, pred_prob_l
+            return pred_df, pred_prob_l
         else:
             return pred_df
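Putting the changed pieces together, a hedged usage sketch assuming an already-fitted `TabularModel` instance named `tabular_model` and pandas DataFrames `train` and `test`; the keyword names follow the parameters appearing in this diff:

```python
from sklearn.model_selection import KFold

cv = KFold(n_splits=5, shuffle=True, random_state=42)

# After this change, return_raw_predictions=True yields the aggregated DataFrame
# plus the per-fold probability arrays only (the separate pred_l is gone).
pred_df, pred_prob_l = tabular_model.bagging_predict(
    cv=cv,
    train=train,
    test=test,
    aggregate="mean",            # or "hard_voting" for classification
    return_raw_predictions=True,
)
```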