diff --git a/econml/_lazy.py b/econml/_lazy.py
new file mode 100644
index 000000000..ce151b0c8
--- /dev/null
+++ b/econml/_lazy.py
@@ -0,0 +1,40 @@
+# Copyright (c) PyWhy contributors. All rights reserved.
+# Licensed under the MIT License.
+
+"""Lazy module loading to avoid expensive imports at package load time."""
+
+import importlib
+
+
+class _LazyModule:
+ """Proxy that delays importing a module until an attribute is accessed.
+
+ Use at module level as a drop-in replacement for ``import heavy_lib``::
+
+ heavy_lib = _LazyModule("heavy_lib")
+
+ The real module is imported on first attribute access, so the cost is
+ deferred until the functionality is actually needed.
+ """
+
+ def __init__(self, module_name):
+ object.__setattr__(self, "_module_name", module_name)
+ object.__setattr__(self, "_module", None)
+
+ def _load(self):
+ module = object.__getattribute__(self, "_module")
+ if module is None:
+ name = object.__getattribute__(self, "_module_name")
+ module = importlib.import_module(name)
+ object.__setattr__(self, "_module", module)
+ return module
+
+ def __getattr__(self, name):
+ return getattr(self._load(), name)
+
+ def __repr__(self):
+ module = object.__getattribute__(self, "_module")
+ if module is not None:
+ return repr(module)
+ name = object.__getattribute__(self, "_module_name")
+ return f"<_LazyModule '{name}' (not yet loaded)>"
diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py
index cf966b631..e27b47a67 100644
--- a/econml/_ortho_learner.py
+++ b/econml/_ortho_learner.py
@@ -41,6 +41,9 @@ class in this module implements the general logic in a very versatile way
filter_none_kwargs, one_hot_encoder, strata_from_discrete_arrays,
jacify_featurizer, reshape, shape)
from .sklearn_extensions.model_selection import ModelSelector
+from ._lazy import _LazyModule
+
+_rlearner = _LazyModule("econml.dml._rlearner") # lazy: avoid circular import
try:
import ray
@@ -1149,9 +1152,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, s
}
# If using an _rlearner, the scoring parameter can be passed along, if provided
if scoring is not None:
- # Cannot import in header, or circular imports
- from .dml._rlearner import _ModelFinal
- if isinstance(self._ortho_learner_model_final, _ModelFinal):
+ if isinstance(self._ortho_learner_model_final, _rlearner._ModelFinal):
score_kwargs['scoring'] = scoring
else:
raise NotImplementedError("scoring parameter only implemented for "
diff --git a/econml/_shap.py b/econml/_shap.py
index a0d75642c..2c84e6f02 100644
--- a/econml/_shap.py
+++ b/econml/_shap.py
@@ -13,11 +13,13 @@
"""
import inspect
-import shap
from collections import defaultdict
import numpy as np
+from ._lazy import _LazyModule
from .utilities import broadcast_unit_treatments, cross_product, get_feature_names_or_default
+shap = _LazyModule("shap") # lazy: heavy dependency only needed when shap_values() is called
+
def _shap_explain_cme(cme_model, X, d_t, d_y,
feature_names=None, treatment_names=None, output_names=None,
diff --git a/econml/data/dynamic_panel_dgp.py b/econml/data/dynamic_panel_dgp.py
index 0eaeb88f3..e14429d98 100644
--- a/econml/data/dynamic_panel_dgp.py
+++ b/econml/data/dynamic_panel_dgp.py
@@ -1,6 +1,5 @@
import numpy as np
-from econml.utilities import cross_product
-from statsmodels.tools.tools import add_constant
+from econml.utilities import cross_product, add_constant
import pandas as pd
import scipy as sp
from scipy.stats import expon
diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py
index e83ca4d4e..d6ff14eeb 100644
--- a/econml/dml/causal_forest.py
+++ b/econml/dml/causal_forest.py
@@ -17,9 +17,12 @@
from .._cate_estimator import LinearCateEstimator
from .._shap import _shap_explain_multitask_model_cate
from .._ortho_learner import _OrthoLearner
+from .._lazy import _LazyModule
from ..validate.sensitivity_analysis import (sensitivity_interval, RV, dml_sensitivity_values,
sensitivity_summary)
+_score = _LazyModule("econml.score") # lazy: avoid circular import
+
class _CausalForestFinalWrapper:
@@ -757,7 +760,7 @@ def tune(self, Y, T, *, X=None, W=None,
The tuned causal forest object. This is the same object (not a copy) as the original one, but where
all parameters of the object have been set to the best performing parameters from the tuning grid.
"""
- from ..score import RScorer # import here to avoid circular import issue
+ RScorer = _score.RScorer
Y, T, X, sample_weight, groups = check_input_arrays(Y, T, X, sample_weight, groups)
W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True,
ensure_2d=True)
diff --git a/econml/inference/_bootstrap.py b/econml/inference/_bootstrap.py
index 7993a9c74..36089bd63 100644
--- a/econml/inference/_bootstrap.py
+++ b/econml/inference/_bootstrap.py
@@ -6,6 +6,9 @@
from joblib import Parallel, delayed
from sklearn.base import clone
from scipy.stats import norm
+from .._lazy import _LazyModule
+
+_cate_estimator = _LazyModule("econml._cate_estimator") # lazy: avoid circular import
class BootstrapEstimator:
@@ -83,10 +86,8 @@ def fit(self, *args, **named_args):
The full signature of this method is the same as that of the wrapped object's `fit` method.
"""
- from .._cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import
-
index_chunks = None
- if isinstance(self._instances[0], BaseCateEstimator):
+ if isinstance(self._instances[0], _cate_estimator.BaseCateEstimator):
index_chunks = self._instances[0]._strata(*args, **named_args)
if index_chunks is not None:
index_chunks = self.__stratified_indices(index_chunks)
diff --git a/econml/sklearn_extensions/linear_model.py b/econml/sklearn_extensions/linear_model.py
index 65efb5fad..8c44ee31f 100644
--- a/econml/sklearn_extensions/linear_model.py
+++ b/econml/sklearn_extensions/linear_model.py
@@ -20,7 +20,7 @@
import warnings
from collections.abc import Iterable
from scipy.stats import norm
-from ..utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays
+from ..utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays, add_constant
import sklearn
from sklearn import clone
from sklearn.linear_model import LinearRegression, LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso
@@ -33,12 +33,14 @@
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_is_fitted
from sklearn.base import BaseEstimator
-from statsmodels.tools.tools import add_constant
-from statsmodels.api import RLM
-import statsmodels
+from .._lazy import _LazyModule
from joblib import Parallel, delayed
from typing import List
+_statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for RLM
+_statsmodels = _LazyModule("statsmodels") # lazy: only needed for RLM robust norms
+_model_selection = _LazyModule("econml.sklearn_extensions.model_selection") # lazy: avoid circular import
+
class _WeightedCVIterableWrapper(_CVIterableWrapper):
def __init__(self, cv):
@@ -56,15 +58,13 @@ def split(self, X=None, y=None, groups=None, sample_weight=None):
def _weighted_check_cv(cv=5, y=None, classifier=False, random_state=None):
- # local import to avoid circular imports
- from .model_selection import WeightedKFold, WeightedStratifiedKFold
cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
if (classifier and (y is not None) and
(type_of_target(y) in ('binary', 'multiclass'))):
- return WeightedStratifiedKFold(cv, random_state=random_state)
+ return _model_selection.WeightedStratifiedKFold(cv, random_state=random_state)
else:
- return WeightedKFold(cv, random_state=random_state)
+ return _model_selection.WeightedKFold(cv, random_state=random_state)
if not hasattr(cv, 'split') or isinstance(cv, str):
if not isinstance(cv, Iterable) or isinstance(cv, str):
@@ -2041,9 +2041,9 @@ def fit(self, X, y):
self._n_out = 0 if len(y.shape) == 1 else (y.shape[1],)
def model_gen(y):
- return RLM(endog=y,
+ return _statsmodels_api.RLM(endog=y,
exog=X,
- M=statsmodels.robust.norms.HuberT(t=self.t)).fit(cov=self.cov_type,
+ M=_statsmodels.robust.norms.HuberT(t=self.t)).fit(cov=self.cov_type,
maxiter=self.maxiter,
tol=self.tol)
if y.ndim < 2:
diff --git a/econml/utilities.py b/econml/utilities.py
index 5673771c8..f378455b8 100644
--- a/econml/utilities.py
+++ b/econml/utilities.py
@@ -19,15 +19,60 @@
from sklearn.preprocessing import OneHotEncoder, PolynomialFeatures, LabelEncoder
import warnings
from warnings import warn
-from statsmodels.iolib.table import SimpleTable
-from statsmodels.iolib.summary import summary_return
+from ._lazy import _LazyModule
from inspect import signature
from packaging.version import parse
+_statsmodels_table = _LazyModule("statsmodels.iolib.table") # lazy: only needed for Summary output
+_statsmodels_summary = _LazyModule("statsmodels.iolib.summary") # lazy: only needed for Summary output
+
MAX_RAND_SEED = np.iinfo(np.int32).max
def add_constant(data, prepend=True, has_constant='skip'):
    """Add a column of ones to a numpy array.

    Parameters
    ----------
    data : array_like
        A column-ordered design matrix; 1-d input is reshaped to a column.
    prepend : bool, default True
        If True the constant is in the first column, else appended.
    has_constant : {'skip', 'add', 'raise'}, default 'skip'
        Behavior when *data* already contains a constant column.
        ``'skip'`` returns *data* unchanged, ``'raise'`` raises
        ``ValueError``, ``'add'`` adds another column of ones anyway.

    Returns
    -------
    ndarray
        The 2-d array with a ones column prepended (or appended).

    Raises
    ------
    TypeError
        If *data* is a pandas DataFrame.
    ValueError
        If *has_constant* is not one of the allowed options, if *data* has
        more than two dimensions, or if ``has_constant='raise'`` and a
        constant column is already present.
    """
    # Reject DataFrames before any conversion so the caller gets a clear
    # message instead of a silently index-less ndarray downstream.
    if isinstance(data, pd.DataFrame):
        raise TypeError(
            "add_constant does not support pandas DataFrames; "
            "pass a numpy array instead"
        )
    # Validate the option eagerly; otherwise a typo would fall through to
    # the constant-detection branch and either succeed silently or raise a
    # misleading "Column(s) ... are constant" error.
    if has_constant not in ('skip', 'add', 'raise'):
        raise ValueError(
            f"has_constant must be 'skip', 'add' or 'raise'; got {has_constant!r}"
        )

    x = np.asarray(data)
    if x.ndim == 1:
        x = x[:, None]
    elif x.ndim > 2:
        raise ValueError('Only implemented for 2-dimensional arrays')

    if has_constant != 'add':
        # A column counts as "constant" only if all entries are equal and
        # nonzero (an all-zero column cannot serve as an intercept).
        is_const = (np.ptp(x, axis=0) == 0) & np.all(x != 0.0, axis=0)
        if is_const.any():
            if has_constant == 'skip':
                return x
            cols = ",".join(str(c) for c in np.where(is_const)[0])
            raise ValueError(f"Column(s) {cols} are constant.")

    ones = np.ones(x.shape[0])
    parts = [ones, x] if prepend else [x, ones]
    return np.column_stack(parts)
+
+
class IdentityFeatures(TransformerMixin):
"""Featurizer that just returns the input data."""
@@ -1147,7 +1192,7 @@ def _repr_html_(self):
return self.as_html()
def add_table(self, res, header, index, title):
- table = SimpleTable(res, header, index, title)
+ table = _statsmodels_table.SimpleTable(res, header, index, title)
self.tables.append(table)
def add_extra_txt(self, etext):
@@ -1170,7 +1215,7 @@ def as_text(self):
summary tables and extra text as one string
"""
- txt = summary_return(self.tables, return_fmt='text')
+ txt = _statsmodels_summary.summary_return(self.tables, return_fmt='text')
if self.extra_txt is not None:
txt = txt + '\n\n' + self.extra_txt
return txt
@@ -1190,7 +1235,7 @@ def as_latex(self):
tables.
"""
- latex = summary_return(self.tables, return_fmt='latex')
+ latex = _statsmodels_summary.summary_return(self.tables, return_fmt='latex')
if self.extra_txt is not None:
latex = latex + '\n\n' + self.extra_txt.replace('\n', ' \\newline\n ')
return latex
@@ -1204,7 +1249,7 @@ def as_csv(self):
concatenated summary tables in comma delimited format
"""
- csv = summary_return(self.tables, return_fmt='csv')
+ csv = _statsmodels_summary.summary_return(self.tables, return_fmt='csv')
if self.extra_txt is not None:
csv = csv + '\n\n' + self.extra_txt
return csv
@@ -1218,7 +1263,7 @@ def as_html(self):
concatenated summary tables in HTML format
"""
- html = summary_return(self.tables, return_fmt='html')
+ html = _statsmodels_summary.summary_return(self.tables, return_fmt='html')
if self.extra_txt is not None:
html = html + '
' + self.extra_txt.replace('\n', '
')
return html
diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py
index c79330218..e90025786 100644
--- a/econml/validate/drtester.py
+++ b/econml/validate/drtester.py
@@ -5,14 +5,13 @@
import scipy.stats as st
from sklearn.model_selection import check_cv
from sklearn.model_selection import cross_val_predict, StratifiedKFold, KFold
-from statsmodels.api import OLS
-from statsmodels.tools import add_constant
-
-from econml.utilities import deprecated
-
+from econml._lazy import _LazyModule
+from econml.utilities import deprecated, add_constant
from .results import CalibrationEvaluationResults, BLPEvaluationResults, UpliftEvaluationResults, EvaluationResults
from .utils import calculate_dr_outcomes, calc_uplift
+_statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for evaluate_blp()
+
class DRTester:
"""
@@ -479,7 +478,7 @@ def evaluate_blp(
self.get_cate_preds(Xval, Xtrain)
if self.n_treat == 1: # binary treatment
- reg = OLS(self.dr_val_, add_constant(self.cate_preds_val_)).fit()
+ reg = _statsmodels_api.OLS(self.dr_val_, add_constant(self.cate_preds_val_)).fit()
params = [reg.params[1]]
errs = [reg.bse[1]]
pvals = [reg.pvalues[1]]
@@ -488,7 +487,10 @@ def evaluate_blp(
errs = []
pvals = []
for k in range(self.n_treat): # run a separate regression for each
- reg = OLS(self.dr_val_[:, k], add_constant(self.cate_preds_val_[:, k])).fit(cov_type='HC1')
+ reg = _statsmodels_api.OLS(
+ self.dr_val_[:, k],
+ add_constant(self.cate_preds_val_[:, k])
+ ).fit(cov_type='HC1')
params.append(reg.params[1])
errs.append(reg.bse[1])
pvals.append(reg.pvalues[1])