
Commit 9b57ce8

Update v0.0.5
1 parent 7733206 commit 9b57ce8

File tree

5 files changed: +93 -48 lines changed

causarray/DR_estimation.py
causarray/DR_learner.py
causarray/__about__.py
causarray/gcate.py
causarray/gcate_opt.py

causarray/DR_estimation.py

Lines changed: 84 additions & 34 deletions
@@ -1,9 +1,12 @@
 import numpy as np
 from sklearn.linear_model import LogisticRegression
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
 from causarray.gcate_glm import fit_glm
 from causarray.utils import *
 from causarray.utils import _filter_params
+from joblib import Parallel, delayed
+from tqdm import tqdm
 import pprint
 
 from sklearn.model_selection import KFold, ShuffleSplit
@@ -82,10 +85,15 @@ def cross_fitting(
         pprint.pprint(params_ps)
         pprint.pprint(params_glm)
 
-    if K>1:
-        # Initialize KFold cross-validator
-        kf = KFold(n_splits=K, random_state=0, shuffle=True)
-        folds = kf.split(X)
+    if K > 1:
+        n_samples = X.shape[0]
+        if K >= n_samples:
+            # Use Leave-One-Out Cross-Validation
+            folds = [([i for i in range(n_samples) if i != j], [j]) for j in range(n_samples)]
+        else:
+            # Initialize KFold cross-validator
+            kf = KFold(n_splits=int(K), random_state=0, shuffle=True)
+            folds = kf.split(X)
     else:
         folds = [(np.arange(X.shape[0]), np.arange(X.shape[0]))]
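
The new branch falls back to leave-one-out folds whenever K reaches the sample size. The list comprehension matches scikit-learn's LeaveOneOut splitter; a minimal sketch of the equivalence (LeaveOneOut itself is not used by the patch):

import numpy as np
from sklearn.model_selection import LeaveOneOut

X = np.arange(10).reshape(5, 2)
n = X.shape[0]

# Fold construction as in the patch: for each j, train on everyone except j.
folds = [([i for i in range(n) if i != j], [j]) for j in range(n)]

# sklearn's LeaveOneOut yields the same index pairs, as arrays.
for (tr, te), (tr2, te2) in zip(folds, LeaveOneOut().split(X)):
    assert list(tr2) == tr and list(te2) == te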

@@ -95,6 +103,16 @@ def cross_fitting(
     fit_Y = True if Y_hat is None else False
     Y_hat = np.zeros((Y.shape[0],Y.shape[1],A.shape[1],2), dtype=float) if fit_Y else Y_hat
 
+    # perform ECV at once
+    if fit_pi and ps_model == 'random_forest_cv':
+        info_ecv = run_ecv(X_A, A, **params_ps)
+        func_ps, params_ps = _get_func_ps(ps_model, verbose=False, ecv=False,
+            kwargs_ensemble=info_ecv['best_params_ensemble'], kwargs_regr=info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the regression model:')
+        pprint.pprint(info_ecv['best_params_regr'])
+        pprint.pprint('Best parameters for the ensemble model:')
+        pprint.pprint(info_ecv['best_params_ensemble'])
+
     # Perform cross-fitting
     for train_index, test_index in folds:
         # Split data
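
Running ECV once up front avoids repeating the hyperparameter search inside every fold; only the cheap refit happens per split. A generic sketch of that pattern (tune_once_then_cross_fit, select_params, and fit_model are hypothetical names, not causarray API):

# Hypothetical illustration: tune hyperparameters once, refit per fold.
def tune_once_then_cross_fit(X, y, folds, select_params, fit_model):
    params = select_params(X, y)   # expensive model selection, run once
    out = []
    for train_idx, test_idx in folds:
        model = fit_model(X[train_idx], y[train_idx], **params)  # cheap refit
        out.append((test_idx, model.predict(X[test_idx])))
    return params, out
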
@@ -178,8 +196,6 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
     tau = np.mean(pseudo_y, axis=0)
 
     return tau, pseudo_y
-
-
 
 
 
@@ -188,51 +204,89 @@ def AIPW_mean(Y, A, mu, pi, positive=False):
 
 
 
-from joblib import Parallel, delayed
-from tqdm import tqdm
-from sklearn_ensemble_cv import reset_random_seeds, Ensemble, ECV
-from sklearn.tree import DecisionTreeRegressor
-
-def fit_rf(X, y, X_test=None, sample_weight=None, M=100, M_max=1000,
+def run_ecv(
+    X, y, M=200, M_max=1000,
     # fixed parameters for bagging regressor
-    kwargs_ensemble={'verbose':1},
+    kwargs_ensemble={},
     # fixed parameters for decision tree
-    kwargs_regr={'min_samples_leaf': 3}, # 'min_samples_split': 10, 'max_features':'sqrt'
+    kwargs_regr={},
     # grid search parameters
-    grid_regr = {'max_depth': [11]},
-    grid_ensemble = {'random_state': 0}, #'max_samples':np.linspace(0.25, 1., 4)
-    ):
+    grid_regr={},
+    grid_ensemble={}
+    ):
+    """
+    Runs Ensemble Cross-Validation (ECV) to find the best hyperparameters.
+    """
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
 
     # Validate integer parameters
     M = int(M)
     M_max = int(M_max)
-    # for kwargs in [kwargs_regr, kwargs_ensemble, grid_regr, grid_ensemble]:
-    #     for param in kwargs:
-    #         if param in ['max_depth', 'random_state', 'max_leaf_nodes'] and isinstance(kwargs[param], float):
-    #             kwargs[param] = int(kwargs[param])
 
     # Make sure y is 2D
     y = y.reshape(-1, 1) if y.ndim == 1 else y
 
     # Run ECV
-    res_ecv, info_ecv = ECV(
-        X, y, DecisionTreeRegressor, grid_regr, grid_ensemble,
-        kwargs_regr, kwargs_ensemble,
+    _, info_ecv = ECV(
+        X, y, DecisionTreeClassifier, grid_regr, grid_ensemble,
+        kwargs_regr, kwargs_ensemble,
         M=M, M0=M, M_max=M_max, return_df=True
     )
 
     # Replace the in-sample best parameter for 'n_estimators' with extrapolated best parameter
     info_ecv['best_params_ensemble']['n_estimators'] = info_ecv['best_n_estimators_extrapolate']
 
+    return info_ecv
+
+
+def fit_rf(
+    X, y, X_test=None, M=100, M_max=1000, ecv=True,
+    # fixed parameters for bagging regressor
+    kwargs_ensemble={},
+    # fixed parameters for decision tree
+    kwargs_regr={},
+    # grid search parameters
+    grid_regr={},
+    grid_ensemble={}
+    ):
+    """
+    Fits a Random Forest model using parameters found by ECV.
+    """
+
+    kwargs_ensemble = {**{'verbose': 1, 'bootstrap': True}, **kwargs_ensemble}
+    kwargs_regr = {**{'min_samples_split': 20, 'min_samples_leaf': 10, 'max_features': 'sqrt', 'ccp_alpha': 0.02, 'class_weight': 'balanced'}, **kwargs_regr}
+    grid_regr = {**{'max_depth': [3, 5, 7]}, **grid_regr}
+    grid_ensemble = {**{'random_state': 0, 'max_samples': [0.4, 0.6, 0.8, 1.]}, **grid_ensemble}
+
+    # Make sure y is 2D
+    y_2d = y.reshape(-1, 1) if y.ndim == 1 else y
+
+    if ecv:
+        # Get best parameters from ECV
+        info_ecv = run_ecv(
+            X, y_2d, M=M, M_max=M_max,
+            kwargs_ensemble=kwargs_ensemble,
+            kwargs_regr=kwargs_regr,
+            grid_regr=grid_regr,
+            grid_ensemble=grid_ensemble
+        )
+        params_regr = info_ecv['best_params_regr']
+        params_ensemble = info_ecv['best_params_ensemble']
+    else:
+        params_regr = kwargs_regr
+        params_ensemble = kwargs_ensemble
+
     # Fit the ensemble with the best CV parameters
     regr = Ensemble(
-        estimator=DecisionTreeRegressor(**info_ecv['best_params_regr']),
-        **info_ecv['best_params_ensemble']).fit(X, y, sample_weight=sample_weight)
-
+        estimator=DecisionTreeClassifier(**params_regr), **params_ensemble).fit(X, y_2d)
+
     # Predict
     if X_test is None:
         X_test = X
-    return regr.predict(X_test).reshape(-1, y.shape[1])
+    return regr.predict(X_test).reshape(-1, y_2d.shape[1])
 
 
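
Both run_ecv and fit_rf merge caller-supplied keyword arguments over built-in defaults with the {**defaults, **overrides} idiom: later entries win when dictionaries are unpacked left to right, so the caller's values take precedence. A minimal sketch:

defaults = {'max_depth': 3, 'random_state': 0}
overrides = {'max_depth': 7}

merged = {**defaults, **overrides}
print(merged)  # {'max_depth': 7, 'random_state': 0} -- caller's value wins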

@@ -252,11 +306,7 @@ def fit_rf_ind_ps(X, Y, *args, **kwargs):
     def _fit(X, y, i_ctrl, *args, **kwargs):
         i_case = (y == 1.)
         i_cells = i_ctrl | i_case
-        sample_weight = np.ones(y.shape[0])
-        class_weight = len(y) / (2 * np.bincount(y.astype(int)))
-        for a in range(2):
-            sample_weight[y == a] = class_weight[a]
-        return fit_rf(X[i_cells], y[i_cells], sample_weight=sample_weight[i_cells], *args, **kwargs)
+        return fit_rf(X[i_cells], y[i_cells], *args, **kwargs)
 
     Y_hat = Parallel(n_jobs=-1)(delayed(_fit)(X, Y[:,j], i_ctrl, *args, **kwargs)
         for j in tqdm(range(Y.shape[1])))
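
The deleted per-cell re-weighting loop is covered by the class_weight='balanced' default that fit_rf now sets in kwargs_regr: scikit-learn's balanced mode weights each class by n_samples / (n_classes * np.bincount(y)), which is the same formula the removed lines computed by hand. A quick check:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 1])  # toy binary labels

# Manual computation, as in the deleted code.
manual = len(y) / (2 * np.bincount(y.astype(int)))

# scikit-learn's 'balanced' heuristic.
balanced = compute_class_weight('balanced', classes=np.array([0, 1]), y=y)

print(manual, balanced)  # both give [0.667, 2.0]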

causarray/DR_learner.py

Lines changed: 5 additions & 7 deletions
@@ -74,7 +74,8 @@ def compute_causal_estimand(
     Y = Y.astype('float')
     n, p = Y.shape
 
-    if A.ndim == 1: A = A[:, None]
+    if len(A.shape) == 1:
+        A = A.reshape(-1,1)
     if isinstance(A, pd.DataFrame):
         trt_names = A.columns
         A = A.values
@@ -169,7 +170,7 @@ def compute_causal_estimand(
 def LFC(
     Y, W, A, W_A=None, family='nb', offset=False,
     Y_hat=None, pi_hat=None, cross_est=False, mask=None, usevar='pooled',
-    thres_min=1e-4, thres_diff=1e-6, eps_var=1e-3,
+    thres_min=1e-2, thres_diff=1e-2, eps_var=1e-4,
     fdx=False, fdx_alpha=0.05, fdx_c=0.1,
     verbose=False, **kwargs):
     '''
@@ -200,9 +201,6 @@ def LFC(
         Boolean mask of shape (n, a) for the treatment, indicating which samples are used for
         the estimation of the estimand. This does not affect the estimation of pseudo-outcomes
         and propensity scores.
-    usevar : str
-        The method to use for estimating the variance of treatment effects.
-        Options are 'pooled' (default) or 'unequal'.
 
     thres_min : float
         The minimum threshold for the treatment effect.
@@ -246,12 +244,12 @@ def estimand(etas, A, **kwargs):
             var_1 = np.var(eta_est[A==1], axis=0, ddof=1)
             n_0 = np.sum(A==0)
             n_1 = np.sum(A==1)
-            var_est = (var_0 + eps_var) / n_0 + (var_1 + eps_var) / n_1
+            var_est = ((var_0 + eps_var) / n_0 + (var_1 + eps_var) / n_1) / 2
         else:
             raise ValueError('usevar must be either "pooled" or "unequal"')
 
         # filter out low-expressed genes
-        idx = (np.maximum(tau_0,tau_1)<thres_min) & ((tau_1-tau_0)<thres_diff)
+        idx = (np.maximum(np.abs(tau_0),np.abs(tau_1))<thres_min) | (np.abs(tau_1-tau_0)<thres_diff)
         tau_est[idx] = 0.; eta_est[:,idx] = 0.; var_est[idx] = np.inf
 
         return eta_est, tau_est, var_est
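
The rewritten filter both fixes a sign issue and loosens the logic: it compares absolute effect sizes, and it now zeroes a gene when either both arms' effects are tiny or the arms barely differ (| rather than &); the variance line above now halves the sum of the two per-arm terms. A small numeric sketch of the filter change, using the new default thresholds:

import numpy as np

thres_min, thres_diff = 1e-2, 1e-2
tau_0 = np.array([-0.5, 0.001, 0.3])
tau_1 = np.array([-0.6, 0.002, 0.3])

old = (np.maximum(tau_0, tau_1) < thres_min) & ((tau_1 - tau_0) < thres_diff)
new = (np.maximum(np.abs(tau_0), np.abs(tau_1)) < thres_min) | (np.abs(tau_1 - tau_0) < thres_diff)

print(old)  # [ True  True False] -- large negative effects wrongly flagged
print(new)  # [False  True  True] -- tiny or near-equal effects flagged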

causarray/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "0.0.4"
+__version__ = "0.0.5"

causarray/gcate.py

Lines changed: 2 additions & 5 deletions
@@ -82,8 +82,7 @@ def fit_gcate(Y, X, A, r, family='nb', disp_glm=None, disp_family=None, offset=T
     kwargs : dict
         Additional keyword arguments.
     '''
-    if X.ndim == 1: X = X[:, None]
-    if A.ndim == 1: A = A[:, None]
+
     X = np.hstack((X, A))
     a = A.shape[1]
     Y, kwargs_glm, lam1 = _check_input(Y, X, family, disp_glm, disp_family, offset, c1, **kwargs)
@@ -196,10 +195,8 @@ def estimate_r(Y, X, A, r_max, c=1.,
     df_r : DataFrame
         Results of the number of latent factors.
     '''
-    if X.ndim == 1: X = X[:, None]
-    if A.ndim == 1: A = A[:, None]
     a, d = A.shape[1], X.shape[1]
-    X = np.hstack((X, A))
+    X = np.hstack((X, A))
     n, p = Y.shape
 
     Y, kwargs_glm, _ = _check_input(Y, X, family, disp_glm, disp_family, offset, None, **kwargs)
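
With the ndim checks removed, fit_gcate and estimate_r no longer promote 1-D inputs to column matrices (compute_causal_estimand still reshapes A, per the DR_learner.py hunk above), so callers are responsible for passing 2-D arrays. A hedged sketch of the reshaping a caller would now do:

import numpy as np

X = np.random.randn(100)              # 1-D covariate vector
A = np.random.binomial(1, 0.5, 100)   # 1-D treatment indicator

# Promote to 2-D column matrices before calling fit_gcate / estimate_r.
X = X.reshape(-1, 1) if X.ndim == 1 else X
A = A.reshape(-1, 1) if A.ndim == 1 else A
print(X.shape, A.shape)  # (100, 1) (100, 1)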

causarray/gcate_opt.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def alter_min(
         kwargs_ls['alpha'] = kwargs_ls['alpha']
     if verbose:
         pprint.pprint({'kwargs_glm':kwargs_glm,'kwargs_ls':kwargs_ls,'kwargs_es':kwargs_es}, compact=True)
-        pprint.pprint(f'Fitting GCATE (step {1 if P1 is None else 2})...')
+        pprint.pprint(f'Fitting GCATE (step {2 if P1 is None else 1})...')
     hist = [func_val_pre]
     es = Early_Stopping(**kwargs_es)
     with tqdm(np.arange(kwargs_es['max_iters']), disable=not verbose) as pbar:
