diff --git a/econml/policy/_drlearner.py b/econml/policy/_drlearner.py
index 8f9ef7082..01a6b0947 100644
--- a/econml/policy/_drlearner.py
+++ b/econml/policy/_drlearner.py
@@ -7,33 +7,50 @@
 from ..utilities import filter_none_kwargs, check_input_arrays
 from ..dr import DRLearner
 from ..dr._drlearner import _ModelFinal
+from ..inference import GenericModelFinalInferenceDiscrete
+from ..grf import RegressionForest
 from ._base import PolicyLearner
 from . import PolicyTree, PolicyForest
 
 
 class _PolicyModelFinal(_ModelFinal):
 
+    def __init__(self, model_final, featurizer, multitask_model_final, cate_model=None):
+        super().__init__(model_final, featurizer, multitask_model_final)
+        self._cate_model = cate_model
+
     def fit(self, Y, T, X=None, W=None, *, nuisances,
             sample_weight=None, freq_weight=None, sample_var=None, groups=None):
         if sample_var is not None:
             warn('Parameter `sample_var` is ignored by the final estimator')
             sample_var = None
-        Y_pred, _, _ = nuisances
+        Y_pred = nuisances[0]
         self.d_y = Y_pred.shape[1:-1]  # track whether there's a Y dimension (must be a singleton)
+        self.d_t = Y_pred.shape[-1] - 1
         if (X is not None) and (self._featurizer is not None):
             X = self._featurizer.fit_transform(X)
         filtered_kwargs = filter_none_kwargs(sample_weight=sample_weight, sample_var=sample_var)
         ys = Y_pred[..., 1:] - Y_pred[..., [0]]  # subtract control results from each other arm
         if self.d_y:  # need to squeeze out singleton so that we fit on 2D array
             ys = ys.squeeze(1)
-        ys = np.hstack([np.zeros((ys.shape[0], 1)), ys])
-        self.model_cate = self._model_final.fit(X, ys, **filtered_kwargs)
+        ys_with_control = np.hstack([np.zeros((ys.shape[0], 1)), ys])
+        self.model_cate = self._model_final.fit(X, ys_with_control, **filtered_kwargs)
+
+        # Also fit per-treatment CATE models for inference support
+        if self._cate_model is not None:
+            self.models_cate = [clone(self._cate_model, safe=False).fit(X, ys[..., t], **filtered_kwargs)
+                                for t in range(self.d_t)]
+
         return self
 
     def predict(self, X=None):
         if (X is not None) and (self._featurizer is not None):
             X = self._featurizer.transform(X)
+        # Use per-treatment CATE models for prediction when available (supports predict_interval)
+        if hasattr(self, 'models_cate') and self.models_cate is not None:
+            preds = np.array([mdl.predict(X).reshape((-1,) + self.d_y) for mdl in self.models_cate])
+            return np.moveaxis(preds, 0, -1)  # move treatment dim to end
         pred = self.model_cate.predict_value(X)[:, 1:]
         if self.d_y:  # need to reintroduce singleton Y dimension
             return pred[:, np.newaxis, :]
@@ -45,8 +62,19 @@ def score(self, Y, T, X=None, W=None, *, nuisances, sample_weight=None, groups=N
 
 
 class _DRLearnerWrapper(DRLearner):
 
+    def __init__(self, *args, cate_model=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._cate_model = cate_model
+
     def _gen_ortho_learner_model_final(self):
-        return _PolicyModelFinal(self._gen_model_final(), self._gen_featurizer(), self.multitask_model_final)
+        return _PolicyModelFinal(self._gen_model_final(), self._gen_featurizer(),
+                                 self.multitask_model_final, cate_model=self._cate_model)
+
+    def _get_inference_options(self):
+        options = super()._get_inference_options()
+        if self._cate_model is not None:
+            options.update(auto=GenericModelFinalInferenceDiscrete)
+        return options
 
 
 class _BaseDRPolicyLearner(PolicyLearner):
@@ -54,7 +82,7 @@ class _BaseDRPolicyLearner(PolicyLearner):
     def _gen_drpolicy_learner(self):
         pass
 
-    def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None):
+    def fit(self, Y, T, *, X=None, W=None, sample_weight=None, groups=None, inference='auto'):
        """
        Estimate a policy model from data.
 
@@ -74,13 +102,18 @@
            All rows corresponding to the same group will be kept together during splitting.
            If groups is not None, the `cv` argument passed to this class's initializer must support a
            'groups' argument to its split method.
+        inference : str or :class:`.Inference` instance, optional
+            Method for performing inference. All estimators support ``'bootstrap'``
+            (or an instance of :class:`.BootstrapInference`). The default is ``'auto'``,
+            which uses the built-in inference method of the underlying DRLearner.
 
        Returns
        -------
        self: object instance
        """
        self.drlearner_ = self._gen_drpolicy_learner()
-        self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups)
+        self.drlearner_.fit(Y, T, X=X, W=W, sample_weight=sample_weight, groups=groups,
+                            inference=inference)
        return self
 
    def predict_value(self, X):
@@ -99,6 +132,245 @@ def predict_value(self, X):
        """
        return self.drlearner_.const_marginal_effect(X)
 
+    # ── CATE estimation and inference (delegated to internal DRLearner) ──
+
+    def effect(self, X=None, *, T0=0, T1=1):
+        """Calculate the heterogeneous treatment effect τ(X) = E[Y(T1) - Y(T0) | X].
+
+        Parameters
+        ----------
+        X : array_like of shape (n_samples, n_features), optional
+            Features for each sample.
+        T0 : int or array_like, default 0
+            Baseline treatment.
+        T1 : int or array_like, default 1
+            Target treatment.
+
+        Returns
+        -------
+        effect : array_like of shape (n_samples, n_outcomes)
+            The heterogeneous treatment effect for each sample.
+        """
+        return self.drlearner_.effect(X, T0=T0, T1=T1)
+
+    def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.05):
+        """Get the confidence interval for the heterogeneous treatment effect τ(X).
+
+        Parameters
+        ----------
+        X : array_like of shape (n_samples, n_features), optional
+            Features for each sample.
+        T0 : int or array_like, default 0
+            Baseline treatment.
+        T1 : int or array_like, default 1
+            Target treatment.
+        alpha : float, default 0.05
+            The significance level; a 100 * (1 - alpha) % confidence interval is returned.
+
+        Returns
+        -------
+        lower, upper : tuple of array_like of shape (n_samples, n_outcomes)
+            Lower and upper bounds of the confidence interval.
+        """
+        return self.drlearner_.effect_interval(X, T0=T0, T1=T1, alpha=alpha)
+
+    def effect_inference(self, X=None, *, T0=0, T1=1):
+        """Get inference results for the heterogeneous treatment effect τ(X).
+
+        Parameters
+        ----------
+        X : array_like of shape (n_samples, n_features), optional
+            Features for each sample.
+        T0 : int or array_like, default 0
+            Baseline treatment.
+        T1 : int or array_like, default 1
+            Target treatment.
+
+        Returns
+        -------
+        inference_results : :class:`~econml.inference.NormalInferenceResults`
+            Inference results including point estimates, confidence intervals, and p-values.
+        """
+        return self.drlearner_.effect_inference(X, T0=T0, T1=T1)
+
+    def const_marginal_effect(self, X=None):
+        """Calculate the constant marginal CATE θ(X) for each non-baseline treatment.
+
+        Parameters
+        ----------
+        X : array_like of shape (n_samples, n_features), optional
+            Features for each sample.
+
+        Returns
+        -------
+        theta : array_like of shape (n_samples, n_treatments - 1)
+            The constant marginal effect for each sample and treatment.
+ """ + return self.drlearner_.const_marginal_effect(X) + + def const_marginal_effect_interval(self, X=None, *, alpha=0.05): + """Get confidence interval for the constant marginal CATE θ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + alpha : float, default 0.05 + The significance level. The confidence interval is (1 - alpha)%. + + Returns + ------- + lower, upper : tuple of array_like + Lower and upper bounds of the confidence interval. + """ + return self.drlearner_.const_marginal_effect_interval(X, alpha=alpha) + + def const_marginal_effect_inference(self, X=None): + """Get inference results for the constant marginal CATE θ(X). + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + inference_results : :class:`~econml.inference.NormalInferenceResults` + Inference results including point estimates, confidence intervals, and p-values. + """ + return self.drlearner_.const_marginal_effect_inference(X) + + def marginal_effect(self, T, X=None): + """Calculate the heterogeneous marginal effect ∂τ(T, X). + + Parameters + ---------- + T : array_like of shape (n_samples,) + Treatment values at which to evaluate the marginal effect. + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + + Returns + ------- + marginal_effect : array_like + The marginal effect for each sample. + """ + return self.drlearner_.marginal_effect(T, X) + + def ate(self, X=None, *, T0=0, T1=1): + """Calculate the average treatment effect E_X[τ(X)]. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + ate : scalar or array_like + The average treatment effect. + """ + return self.drlearner_.ate(X, T0=T0, T1=T1) + + def ate_interval(self, X=None, *, T0=0, T1=1, alpha=0.05): + """Get confidence interval for the average treatment effect. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + alpha : float, default 0.05 + The significance level. + + Returns + ------- + lower, upper : tuple of scalars or array_like + Lower and upper bounds of the confidence interval. + """ + return self.drlearner_.ate_interval(X, T0=T0, T1=T1, alpha=alpha) + + def ate_inference(self, X=None, *, T0=0, T1=1): + """Get inference results for the average treatment effect. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features), optional + Features for each sample. + T0 : int or array_like, default 0 + Baseline treatment. + T1 : int or array_like, default 1 + Target treatment. + + Returns + ------- + inference_results : :class:`~econml.inference.NormalInferenceResults` + Inference results including point estimates, confidence intervals, and p-values. + """ + return self.drlearner_.ate_inference(X, T0=T0, T1=T1) + + def shap_values(self, X, *, feature_names=None, treatment_names=None, + output_names=None, background_samples=100): + """Get SHAP values for the CATE model. + + Parameters + ---------- + X : array_like of shape (n_samples, n_features) + Features for each sample. + feature_names : list of str, optional + The names of the input features. 
+        treatment_names : list of str, optional
+            The names of the treatments.
+        output_names : list of str, optional
+            The names of the outputs.
+        background_samples : int, default 100
+            Number of background samples for SHAP.
+
+        Returns
+        -------
+        shap_values : object
+            SHAP values for the CATE model.
+        """
+        return self.drlearner_.shap_values(X, feature_names=feature_names,
+                                           treatment_names=treatment_names,
+                                           output_names=output_names,
+                                           background_samples=background_samples)
+
+    def score(self, Y, T, X=None, W=None, *, sample_weight=None):
+        """Score the fitted CATE model on new data.
+
+        Parameters
+        ----------
+        Y : array_like of shape (n_samples,)
+            Outcomes for each sample.
+        T : array_like of shape (n_samples,)
+            Treatments for each sample.
+        X : array_like of shape (n_samples, n_features), optional
+            Features for each sample.
+        W : array_like of shape (n_samples, n_controls), optional
+            Controls for each sample.
+        sample_weight : array_like of shape (n_samples,), optional
+            Weights for each sample.
+
+        Returns
+        -------
+        score : float
+            The score of the CATE model.
+        """
+        return self.drlearner_.score(Y, T, X=X, W=W, sample_weight=sample_weight)
+
+    @property
+    def model_final_(self):
+        """The fitted final model of the underlying DRLearner."""
+        return self.drlearner_.model_final_
+
    def predict_proba(self, X):
        """Predict the probability of recommending each treatment.
 
@@ -436,6 +708,10 @@ def _gen_drpolicy_learner(self):
                                                       honest=self.honest,
                                                       random_state=self.random_state),
                                multitask_model_final=True,
+                                cate_model=RegressionForest(
+                                    min_samples_leaf=self.min_samples_leaf,
+                                    honest=self.honest,
+                                    random_state=self.random_state),
                                random_state=self.random_state)
 
    def plot(self, *, feature_names=None, treatment_names=None, ax=None, title=None,
@@ -868,6 +1144,12 @@ def _gen_drpolicy_learner(self):
                                                         verbose=self.verbose,
                                                         random_state=self.random_state),
                                multitask_model_final=True,
+                                cate_model=RegressionForest(
+                                    n_estimators=max(4, 4 * (self.n_estimators // 4)),
+                                    min_samples_leaf=self.min_samples_leaf,
+                                    honest=self.honest,
+                                    n_jobs=self.n_jobs,
+                                    random_state=self.random_state),
                                random_state=self.random_state)
 
    def plot(self, tree_id, *, feature_names=None, treatment_names=None,
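
A minimal usage sketch of the surface this patch adds, assuming the changes flow through to the public `DRPolicyForest` estimator whose `_gen_drpolicy_learner` is modified in the final hunk; the data generation and constructor arguments below are illustrative only, not part of the patch.

    import numpy as np
    from econml.policy import DRPolicyForest

    # Toy data: 3 arms (0 = control); the effect of arm 1 rises with X[:, 0]
    rng = np.random.default_rng(0)
    n = 2000
    X = rng.normal(size=(n, 4))
    T = rng.integers(0, 3, size=n)
    Y = (T == 1) * X[:, 0] - (T == 2) * X[:, 0] + rng.normal(scale=0.5, size=n)

    est = DRPolicyForest(n_estimators=200, min_samples_leaf=20, random_state=0)
    est.fit(Y, T, X=X, inference='auto')          # new `inference` argument

    tau = est.effect(X[:5], T0=0, T1=1)           # delegated CATE point estimates
    lb, ub = est.effect_interval(X[:5], T0=0, T1=1, alpha=0.05)  # new intervals
    print(est.ate_inference(X))                   # ATE with standard errors
    best_arm = est.predict(X[:5])                 # policy recommendation, as before

With `inference='auto'`, `_DRLearnerWrapper._get_inference_options` maps `'auto'` to `GenericModelFinalInferenceDiscrete`, which reads the per-treatment `RegressionForest` CATE models that `_PolicyModelFinal.fit` now trains alongside the multitask policy model; `'bootstrap'` remains available as before.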