@@ -1,9 +1,12 @@
 import numpy as np
 from sklearn.linear_model import LogisticRegression
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
 from causarray.gcate_glm import fit_glm
 from causarray.utils import *
 from causarray.utils import _filter_params
+from joblib import Parallel, delayed
+from tqdm import tqdm
 import pprint
 
 from sklearn.model_selection import KFold, ShuffleSplit
@@ -82,10 +85,15 @@ def cross_fitting(
     pprint.pprint(params_ps)
     pprint.pprint(params_glm)
 
-    if K > 1:
-        # Initialize KFold cross-validator
-        kf = KFold(n_splits=K, random_state=0, shuffle=True)
-        folds = kf.split(X)
+    if K > 1:
+        n_samples = X.shape[0]
+        if K >= n_samples:
+            # Use leave-one-out cross-validation
+            folds = [([i for i in range(n_samples) if i != j], [j]) for j in range(n_samples)]
+        else:
+            # Initialize KFold cross-validator
+            kf = KFold(n_splits=int(K), random_state=0, shuffle=True)
+            folds = kf.split(X)
     else:
         folds = [(np.arange(X.shape[0]), np.arange(X.shape[0]))]
 
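The manual list comprehension above builds leave-one-out splits by hand. For comparison only, a minimal sketch of the same train/test indices via scikit-learn's built-in splitter (assuming `X` is the design matrix as in this function):

    from sklearn.model_selection import LeaveOneOut

    # Each element is (train_indices, test_indices) with a single test index,
    # matching the comprehension used when K >= n_samples.
    folds = list(LeaveOneOut().split(X))
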
@@ -95,6 +103,16 @@ def cross_fitting(
     fit_Y = True if Y_hat is None else False
     Y_hat = np.zeros((Y.shape[0], Y.shape[1], A.shape[1], 2), dtype=float) if fit_Y else Y_hat
 
+    # Perform ECV once, up front, so the hyperparameter search is not repeated per fold
+    if fit_pi and ps_model == 'random_forest_cv':
+        info_ecv = run_ecv(X_A, A, **params_ps)
+        func_ps, params_ps = _get_func_ps(ps_model, verbose=False, ecv=False,
+            kwargs_ensemble=info_ecv['best_params_ensemble'], kwargs_regr=info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the regression model:')
+        pprint.pprint(info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the ensemble model:')
+        pprint.pprint(info_ecv['best_params_ensemble'])
+
     # Perform cross-fitting
     for train_index, test_index in folds:
         # Split data
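Running the ECV search once on the full data and then passing `ecv=False` downstream means each fold only refits with the already-selected hyperparameters. A hedged sketch of that control flow, reusing names from this diff (`run_ecv`, `fit_rf`, `X_A`, `A`, `folds` are as defined here; shapes assumed):

    # Sketch: tune once, then refit per fold with fixed parameters.
    info_ecv = run_ecv(X_A, A)                    # single hyperparameter search
    best_regr = info_ecv['best_params_regr']
    best_ens = info_ecv['best_params_ensemble']   # holds the extrapolated n_estimators
    for train_index, test_index in folds:
        # ecv=False: each fold refits with the already-selected parameters
        pi_hat = fit_rf(X_A[train_index], A[train_index], X_test=X_A[test_index],
                        ecv=False, kwargs_regr=best_regr, kwargs_ensemble=best_ens)
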
@@ -178,8 +196,6 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
     tau = np.mean(pseudo_y, axis=0)
 
     return tau, pseudo_y
-
-
 
 
 
@@ -188,51 +204,89 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
 
 
 
-from joblib import Parallel, delayed
-from tqdm import tqdm
-from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
-from sklearn.tree import DecisionTreeRegressor
-
-def fit_rf(X, y, X_test=None, sample_weight=None, M=100, M_max=1000,
+def run_ecv(
+    X, y, M=200, M_max=1000,
     # fixed parameters for bagging regressor
-    kwargs_ensemble={'verbose': 1},
+    kwargs_ensemble={},
     # fixed parameters for decision tree
-    kwargs_regr={'min_samples_leaf': 3},  # 'min_samples_split': 10, 'max_features': 'sqrt'
+    kwargs_regr={},
     # grid search parameters
-    grid_regr={'max_depth': [11]},
-    grid_ensemble={'random_state': 0},  # 'max_samples': np.linspace(0.25, 1., 4)
-):
+    grid_regr={},
+    grid_ensemble={}
+):
+    """
+    Runs Ensemble Cross-Validation (ECV) to find the best hyperparameters.
+    """
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
 
     # Validate integer parameters
     M = int(M)
     M_max = int(M_max)
-    # for kwargs in [kwargs_regr, kwargs_ensemble, grid_regr, grid_ensemble]:
-    #     for param in kwargs:
-    #         if param in ['max_depth', 'random_state', 'max_leaf_nodes'] and isinstance(kwargs[param], float):
-    #             kwargs[param] = int(kwargs[param])
 
     # Make sure y is 2D
     y = y.reshape(-1, 1) if y.ndim == 1 else y
 
     # Run ECV
-    res_ecv, info_ecv = ECV(
-        X, y, DecisionTreeRegressor, grid_regr, grid_ensemble,
-        kwargs_regr, kwargs_ensemble,
+    _, info_ecv = ECV(
+        X, y, DecisionTreeClassifier, grid_regr, grid_ensemble,
+        kwargs_regr, kwargs_ensemble,
         M=M, M0=M, M_max=M_max, return_df=True
     )
 
     # Replace the in-sample best parameter for 'n_estimators' with the extrapolated best parameter
     info_ecv['best_params_ensemble']['n_estimators'] = info_ecv['best_n_estimators_extrapolate']
 
+    return info_ecv
+
+
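A minimal usage sketch of `run_ecv` as defined above. The data here are synthetic stand-ins, and the dictionary keys (`best_params_regr`, `best_params_ensemble`) follow the `ECV` return value used in this diff:

    rng = np.random.default_rng(0)
    X_demo = rng.normal(size=(200, 5))
    a_demo = rng.binomial(1, 0.5, size=200).astype(float)

    info = run_ecv(X_demo, a_demo, M=20, M_max=100)
    pprint.pprint(info['best_params_regr'])      # tuned tree parameters, e.g. max_depth
    pprint.pprint(info['best_params_ensemble'])  # n_estimators is the extrapolated value
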
+def fit_rf(
+    X, y, X_test=None, M=100, M_max=1000, ecv=True,
+    # fixed parameters for bagging regressor
+    kwargs_ensemble={},
+    # fixed parameters for decision tree
+    kwargs_regr={},
+    # grid search parameters
+    grid_regr={},
+    grid_ensemble={}
+):
+    """
+    Fits a Random Forest model using parameters found by ECV.
+    """
+
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
+
+    # Make sure y is 2D
+    y_2d = y.reshape(-1, 1) if y.ndim == 1 else y
+
+    if ecv:
+        # Get the best parameters from ECV
+        info_ecv = run_ecv(
+            X, y_2d, M=M, M_max=M_max,
+            kwargs_ensemble=kwargs_ensemble,
+            kwargs_regr=kwargs_regr,
+            grid_regr=grid_regr,
+            grid_ensemble=grid_ensemble
+        )
+        params_regr = info_ecv['best_params_regr']
+        params_ensemble = info_ecv['best_params_ensemble']
+    else:
+        params_regr = kwargs_regr
+        params_ensemble = kwargs_ensemble
+
     # Fit the ensemble with the best CV parameters
     regr = Ensemble(
-        estimator=DecisionTreeRegressor(**info_ecv['best_params_regr']),
-        **info_ecv['best_params_ensemble']).fit(X, y, sample_weight=sample_weight)
-
+        estimator=DecisionTreeClassifier(**params_regr), **params_ensemble).fit(X, y_2d)
+
     # Predict
     if X_test is None:
         X_test = X
-    return regr.predict(X_test).reshape(-1, y.shape[1])
+    return regr.predict(X_test).reshape(-1, y_2d.shape[1])
 
 
 
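Because defaults are merged with `{**defaults, **overrides}`, a caller can override a single parameter without restating the rest. A usage sketch under the same synthetic-data assumptions as the earlier `run_ecv` example:

    # Override one tree parameter; all other defaults are kept.
    pi_hat = fit_rf(X_demo, a_demo, M=20, M_max=100,
                    kwargs_regr={'ccp_alpha': 0.0})
    print(pi_hat.shape)  # (200, 1): with X_test=None, predictions are made on X itself
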
@@ -252,11 +306,7 @@ def fit_rf_ind_ps(X, Y, *args, **kwargs):
     def _fit(X, y, i_ctrl, *args, **kwargs):
         i_case = (y == 1.)
         i_cells = i_ctrl | i_case
-        sample_weight = np.ones(y.shape[0])
-        class_weight = len(y) / (2 * np.bincount(y.astype(int)))
-        for a in range(2):
-            sample_weight[y == a] = class_weight[a]
-        return fit_rf(X[i_cells], y[i_cells], sample_weight=sample_weight[i_cells], *args, **kwargs)
+        return fit_rf(X[i_cells], y[i_cells], *args, **kwargs)
 
     Y_hat = Parallel(n_jobs=-1)(delayed(_fit)(X, Y[:, j], i_ctrl, *args, **kwargs)
                                 for j in tqdm(range(Y.shape[1])))
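The `Parallel`/`delayed` loop above fits one model per column of `Y`, restricted to control cells plus that column's cases. A generic sketch of the same per-column pattern with a hypothetical placeholder fit (`_per_column_fit` stands in for `fit_rf`; data reuse the synthetic `rng` and `X_demo` from the earlier sketch):

    # Synthetic per-column outcomes.
    Y_demo = rng.binomial(1, 0.3, size=(200, 4)).astype(float)

    def _per_column_fit(X, y):
        # placeholder standing in for fit_rf in the code above
        return y.mean()

    stats = Parallel(n_jobs=-1)(delayed(_per_column_fit)(X_demo, Y_demo[:, j])
                                for j in tqdm(range(Y_demo.shape[1])))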