diff --git a/econml/tests/test_drtester.py b/econml/tests/test_drtester.py index 3396f0fe2..32dd404c0 100644 --- a/econml/tests/test_drtester.py +++ b/econml/tests/test_drtester.py @@ -286,3 +286,34 @@ def test_exceptions(self): autoc_res = my_dr_tester.evaluate_uplift(Xval, Xtrain, metric='toc') self.assertLess(autoc_res.pvals[0], 0.05) + + def test_dataframe_input(self): + Xtrain, Dtrain, Ytrain, Xval, Dval, Yval = self._get_data(num_treatments=1) + + reg_t = RandomForestClassifier(random_state=0) + reg_y = GradientBoostingRegressor(random_state=0) + + cate = DML( + model_y=reg_y, + model_t=reg_t, + model_final=reg_y, + discrete_treatment=True + ).fit(Y=Ytrain, T=Dtrain, X=Xtrain) + + # Fit with numpy arrays as baseline + dr_numpy = DRTester( + model_regression=reg_y, + model_propensity=reg_t, + cate=cate + ).fit_nuisance(Xval, Dval, Yval, Xtrain, Dtrain, Ytrain) + + # Fit with DataFrames/Series + dr_pandas = DRTester( + model_regression=reg_y, + model_propensity=reg_t, + cate=cate + ).fit_nuisance( + pd.DataFrame(Xval), pd.Series(Dval), pd.Series(Yval), + pd.DataFrame(Xtrain), pd.Series(Dtrain), pd.Series(Ytrain) + ) + np.testing.assert_array_equal(dr_pandas.dr_val_, dr_numpy.dr_val_) diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py index c79330218..60cf3fb2a 100644 --- a/econml/validate/drtester.py +++ b/econml/validate/drtester.py @@ -8,7 +8,7 @@ from statsmodels.api import OLS from statsmodels.tools import add_constant -from econml.utilities import deprecated +from econml.utilities import check_input_arrays, deprecated from .results import CalibrationEvaluationResults, BLPEvaluationResults, UpliftEvaluationResults, EvaluationResults from .utils import calculate_dr_outcomes, calc_uplift @@ -221,6 +221,9 @@ def fit_nuisance( If training data is provided, also adds attributes for the doubly robust outcomes for the training set (dr_train) and the training treatments (Dtrain) """ + Xval, Dval, yval = check_input_arrays(Xval, Dval, yval) + Xtrain, Dtrain, ytrain = check_input_arrays(Xtrain, Dtrain, ytrain) + self.Dval = Dval # Unique treatments (ordered, includes control)