From db8d9c584dc676c68d910ceac6d4b328d97e4dfd Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Fri, 10 Apr 2026 12:07:49 -0400 Subject: [PATCH 1/3] Lazy-load shap and statsmodels to reduce import overhead Add a _LazyModule proxy class (econml/_lazy.py) that defers module loading until first attribute access. This keeps lazy import declarations at the top of each file alongside normal imports, making the deferred loading explicit and avoiding scattered inline imports inside function bodies. Modules deferred: - shap (+numba, sparse) in econml/_shap.py - statsmodels.iolib.{table,summary} in econml/utilities.py - statsmodels.{tools,api,robust} in econml/sklearn_extensions/linear_model.py - statsmodels.tools.tools in econml/data/dynamic_panel_dgp.py - statsmodels.{api,tools} in econml/validate/drtester.py Measured improvement: single-test cold start ~12s -> ~7s (39% faster). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Keith Battocchi --- econml/_lazy.py | 40 +++++++++++++++++++++++ econml/_shap.py | 4 ++- econml/data/dynamic_panel_dgp.py | 5 ++- econml/sklearn_extensions/linear_model.py | 20 +++++++----- econml/utilities.py | 16 +++++---- econml/validate/drtester.py | 15 +++++---- 6 files changed, 76 insertions(+), 24 deletions(-) create mode 100644 econml/_lazy.py diff --git a/econml/_lazy.py b/econml/_lazy.py new file mode 100644 index 000000000..ce151b0c8 --- /dev/null +++ b/econml/_lazy.py @@ -0,0 +1,40 @@ +# Copyright (c) PyWhy contributors. All rights reserved. +# Licensed under the MIT License. + +"""Lazy module loading to avoid expensive imports at package load time.""" + +import importlib + + +class _LazyModule: + """Proxy that delays importing a module until an attribute is accessed. + + Use at module level as a drop-in replacement for ``import heavy_lib``:: + + heavy_lib = _LazyModule("heavy_lib") + + The real module is imported on first attribute access, so the cost is + deferred until the functionality is actually needed. + """ + + def __init__(self, module_name): + object.__setattr__(self, "_module_name", module_name) + object.__setattr__(self, "_module", None) + + def _load(self): + module = object.__getattribute__(self, "_module") + if module is None: + name = object.__getattribute__(self, "_module_name") + module = importlib.import_module(name) + object.__setattr__(self, "_module", module) + return module + + def __getattr__(self, name): + return getattr(self._load(), name) + + def __repr__(self): + module = object.__getattribute__(self, "_module") + if module is not None: + return repr(module) + name = object.__getattribute__(self, "_module_name") + return f"<_LazyModule '{name}' (not yet loaded)>" diff --git a/econml/_shap.py b/econml/_shap.py index a0d75642c..2c84e6f02 100644 --- a/econml/_shap.py +++ b/econml/_shap.py @@ -13,11 +13,13 @@ """ import inspect -import shap from collections import defaultdict import numpy as np +from ._lazy import _LazyModule from .utilities import broadcast_unit_treatments, cross_product, get_feature_names_or_default +shap = _LazyModule("shap") # lazy: heavy dependency only needed when shap_values() is called + def _shap_explain_cme(cme_model, X, d_t, d_y, feature_names=None, treatment_names=None, output_names=None, diff --git a/econml/data/dynamic_panel_dgp.py b/econml/data/dynamic_panel_dgp.py index 0eaeb88f3..c59cd51b7 100644 --- a/econml/data/dynamic_panel_dgp.py +++ b/econml/data/dynamic_panel_dgp.py @@ -1,6 +1,6 @@ import numpy as np from econml.utilities import cross_product -from statsmodels.tools.tools import add_constant +from econml._lazy import _LazyModule import pandas as pd import scipy as sp from scipy.stats import expon @@ -9,6 +9,8 @@ import joblib import os +_statsmodels_tools = _LazyModule("statsmodels.tools.tools") # lazy: only needed in create_instance() + dir = os.path.dirname(__file__) @@ -304,6 +306,7 @@ def create_instance(self, s_x, sigma_x, sigma_y, conf_str, epsilon, Alpha_unnorm self.true_effect[t, :] = (self.zeta.reshape( 1, -1) @ np.linalg.matrix_power(self.Beta, t - 1) @ self.Alpha) + add_constant = _statsmodels_tools.add_constant self.true_hetero_effect = np.zeros( (self.n_periods, (self.n_x + 1) * self.n_treatments)) self.true_hetero_effect[0, :] = cross_product(add_constant(self.y_hetero_effect.reshape(1, -1), diff --git a/econml/sklearn_extensions/linear_model.py b/econml/sklearn_extensions/linear_model.py index 65efb5fad..6bb271929 100644 --- a/econml/sklearn_extensions/linear_model.py +++ b/econml/sklearn_extensions/linear_model.py @@ -33,12 +33,14 @@ from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import check_is_fitted from sklearn.base import BaseEstimator -from statsmodels.tools.tools import add_constant -from statsmodels.api import RLM -import statsmodels +from .._lazy import _LazyModule from joblib import Parallel, delayed from typing import List +_statsmodels_tools = _LazyModule("statsmodels.tools.tools") # lazy: only needed in fit/predict methods +_statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for RLM +_statsmodels = _LazyModule("statsmodels") # lazy: only needed for RLM robust norms + class _WeightedCVIterableWrapper(_CVIterableWrapper): def __init__(self, cv): @@ -1539,7 +1541,7 @@ def predict(self, X): if X is None: X = np.empty((1, 0)) if self.fit_intercept: - X = add_constant(X, has_constant='add') + X = _statsmodels_tools.add_constant(X, has_constant='add') return np.matmul(X, self._param) @property @@ -1634,7 +1636,7 @@ def prediction_stderr(self, X): if X is None: X = np.empty((1, 0)) if self.fit_intercept: - X = add_constant(X, has_constant='add') + X = _statsmodels_tools.add_constant(X, has_constant='add') if self._n_out == 0: return np.sqrt(np.clip(np.sum(np.matmul(X, self._param_var) * X, axis=1), 0, np.inf)) else: @@ -1735,7 +1737,7 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var): if X is None: X = np.empty((y.shape[0], 0)) if self.fit_intercept: - X = add_constant(X, has_constant='add') + X = _statsmodels_tools.add_constant(X, has_constant='add') # set default values for None if sample_weight is None: @@ -2036,14 +2038,14 @@ def fit(self, X, y): """ X, y = self._check_input(X, y) if self.fit_intercept: - X = add_constant(X, has_constant='add') + X = _statsmodels_tools.add_constant(X, has_constant='add') self._n_out = 0 if len(y.shape) == 1 else (y.shape[1],) def model_gen(y): - return RLM(endog=y, + return _statsmodels_api.RLM(endog=y, exog=X, - M=statsmodels.robust.norms.HuberT(t=self.t)).fit(cov=self.cov_type, + M=_statsmodels.robust.norms.HuberT(t=self.t)).fit(cov=self.cov_type, maxiter=self.maxiter, tol=self.tol) if y.ndim < 2: diff --git a/econml/utilities.py b/econml/utilities.py index 5673771c8..3415fc0a2 100644 --- a/econml/utilities.py +++ b/econml/utilities.py @@ -19,11 +19,13 @@ from sklearn.preprocessing import OneHotEncoder, PolynomialFeatures, LabelEncoder import warnings from warnings import warn -from statsmodels.iolib.table import SimpleTable -from statsmodels.iolib.summary import summary_return +from ._lazy import _LazyModule from inspect import signature from packaging.version import parse +_statsmodels_table = _LazyModule("statsmodels.iolib.table") # lazy: only needed for Summary output +_statsmodels_summary = _LazyModule("statsmodels.iolib.summary") # lazy: only needed for Summary output + MAX_RAND_SEED = np.iinfo(np.int32).max @@ -1147,7 +1149,7 @@ def _repr_html_(self): return self.as_html() def add_table(self, res, header, index, title): - table = SimpleTable(res, header, index, title) + table = _statsmodels_table.SimpleTable(res, header, index, title) self.tables.append(table) def add_extra_txt(self, etext): @@ -1170,7 +1172,7 @@ def as_text(self): summary tables and extra text as one string """ - txt = summary_return(self.tables, return_fmt='text') + txt = _statsmodels_summary.summary_return(self.tables, return_fmt='text') if self.extra_txt is not None: txt = txt + '\n\n' + self.extra_txt return txt @@ -1190,7 +1192,7 @@ def as_latex(self): tables. """ - latex = summary_return(self.tables, return_fmt='latex') + latex = _statsmodels_summary.summary_return(self.tables, return_fmt='latex') if self.extra_txt is not None: latex = latex + '\n\n' + self.extra_txt.replace('\n', ' \\newline\n ') return latex @@ -1204,7 +1206,7 @@ def as_csv(self): concatenated summary tables in comma delimited format """ - csv = summary_return(self.tables, return_fmt='csv') + csv = _statsmodels_summary.summary_return(self.tables, return_fmt='csv') if self.extra_txt is not None: csv = csv + '\n\n' + self.extra_txt return csv @@ -1218,7 +1220,7 @@ def as_html(self): concatenated summary tables in HTML format """ - html = summary_return(self.tables, return_fmt='html') + html = _statsmodels_summary.summary_return(self.tables, return_fmt='html') if self.extra_txt is not None: html = html + '

' + self.extra_txt.replace('\n', '
') return html diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py index c79330218..85f08271f 100644 --- a/econml/validate/drtester.py +++ b/econml/validate/drtester.py @@ -5,14 +5,14 @@ import scipy.stats as st from sklearn.model_selection import check_cv from sklearn.model_selection import cross_val_predict, StratifiedKFold, KFold -from statsmodels.api import OLS -from statsmodels.tools import add_constant - +from econml._lazy import _LazyModule from econml.utilities import deprecated - from .results import CalibrationEvaluationResults, BLPEvaluationResults, UpliftEvaluationResults, EvaluationResults from .utils import calculate_dr_outcomes, calc_uplift +_statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for evaluate_blp() +_statsmodels_tools = _LazyModule("statsmodels.tools") # lazy: only needed for evaluate_blp() + class DRTester: """ @@ -479,7 +479,7 @@ def evaluate_blp( self.get_cate_preds(Xval, Xtrain) if self.n_treat == 1: # binary treatment - reg = OLS(self.dr_val_, add_constant(self.cate_preds_val_)).fit() + reg = _statsmodels_api.OLS(self.dr_val_, _statsmodels_tools.add_constant(self.cate_preds_val_)).fit() params = [reg.params[1]] errs = [reg.bse[1]] pvals = [reg.pvalues[1]] @@ -488,7 +488,10 @@ def evaluate_blp( errs = [] pvals = [] for k in range(self.n_treat): # run a separate regression for each - reg = OLS(self.dr_val_[:, k], add_constant(self.cate_preds_val_[:, k])).fit(cov_type='HC1') + reg = _statsmodels_api.OLS( + self.dr_val_[:, k], + _statsmodels_tools.add_constant(self.cate_preds_val_[:, k]) + ).fit(cov_type='HC1') params.append(reg.params[1]) errs.append(reg.bse[1]) pvals.append(reg.pvalues[1]) From c39680ad86b69cd0ce48ae85d6698af9bc2febbd Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Fri, 10 Apr 2026 12:35:27 -0400 Subject: [PATCH 2/3] Replace statsmodels.add_constant with local implementation Add a lightweight add_constant() to econml/utilities.py that handles the numpy-array case directly, with a guard that raises TypeError for pandas DataFrames. This eliminates the statsmodels dependency from dynamic_panel_dgp.py entirely, and removes the _statsmodels_tools lazy import from linear_model.py and drtester.py. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Keith Battocchi --- econml/data/dynamic_panel_dgp.py | 6 +--- econml/sklearn_extensions/linear_model.py | 11 +++--- econml/utilities.py | 43 +++++++++++++++++++++++ econml/validate/drtester.py | 7 ++-- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/econml/data/dynamic_panel_dgp.py b/econml/data/dynamic_panel_dgp.py index c59cd51b7..e14429d98 100644 --- a/econml/data/dynamic_panel_dgp.py +++ b/econml/data/dynamic_panel_dgp.py @@ -1,6 +1,5 @@ import numpy as np -from econml.utilities import cross_product -from econml._lazy import _LazyModule +from econml.utilities import cross_product, add_constant import pandas as pd import scipy as sp from scipy.stats import expon @@ -9,8 +8,6 @@ import joblib import os -_statsmodels_tools = _LazyModule("statsmodels.tools.tools") # lazy: only needed in create_instance() - dir = os.path.dirname(__file__) @@ -306,7 +303,6 @@ def create_instance(self, s_x, sigma_x, sigma_y, conf_str, epsilon, Alpha_unnorm self.true_effect[t, :] = (self.zeta.reshape( 1, -1) @ np.linalg.matrix_power(self.Beta, t - 1) @ self.Alpha) - add_constant = _statsmodels_tools.add_constant self.true_hetero_effect = np.zeros( (self.n_periods, (self.n_x + 1) * self.n_treatments)) self.true_hetero_effect[0, :] = cross_product(add_constant(self.y_hetero_effect.reshape(1, -1), diff --git a/econml/sklearn_extensions/linear_model.py b/econml/sklearn_extensions/linear_model.py index 6bb271929..4d774742f 100644 --- a/econml/sklearn_extensions/linear_model.py +++ b/econml/sklearn_extensions/linear_model.py @@ -20,7 +20,7 @@ import warnings from collections.abc import Iterable from scipy.stats import norm -from ..utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays +from ..utilities import ndim, shape, reshape, _safe_norm_ppf, check_input_arrays, add_constant import sklearn from sklearn import clone from sklearn.linear_model import LinearRegression, LassoCV, MultiTaskLassoCV, Lasso, MultiTaskLasso @@ -37,7 +37,6 @@ from joblib import Parallel, delayed from typing import List -_statsmodels_tools = _LazyModule("statsmodels.tools.tools") # lazy: only needed in fit/predict methods _statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for RLM _statsmodels = _LazyModule("statsmodels") # lazy: only needed for RLM robust norms @@ -1541,7 +1540,7 @@ def predict(self, X): if X is None: X = np.empty((1, 0)) if self.fit_intercept: - X = _statsmodels_tools.add_constant(X, has_constant='add') + X = add_constant(X, has_constant='add') return np.matmul(X, self._param) @property @@ -1636,7 +1635,7 @@ def prediction_stderr(self, X): if X is None: X = np.empty((1, 0)) if self.fit_intercept: - X = _statsmodels_tools.add_constant(X, has_constant='add') + X = add_constant(X, has_constant='add') if self._n_out == 0: return np.sqrt(np.clip(np.sum(np.matmul(X, self._param_var) * X, axis=1), 0, np.inf)) else: @@ -1737,7 +1736,7 @@ def _check_input(self, X, y, sample_weight, freq_weight, sample_var): if X is None: X = np.empty((y.shape[0], 0)) if self.fit_intercept: - X = _statsmodels_tools.add_constant(X, has_constant='add') + X = add_constant(X, has_constant='add') # set default values for None if sample_weight is None: @@ -2038,7 +2037,7 @@ def fit(self, X, y): """ X, y = self._check_input(X, y) if self.fit_intercept: - X = _statsmodels_tools.add_constant(X, has_constant='add') + X = add_constant(X, has_constant='add') self._n_out = 0 if len(y.shape) == 1 else (y.shape[1],) diff --git a/econml/utilities.py b/econml/utilities.py index 3415fc0a2..f378455b8 100644 --- a/econml/utilities.py +++ b/econml/utilities.py @@ -30,6 +30,49 @@ MAX_RAND_SEED = np.iinfo(np.int32).max +def add_constant(data, prepend=True, has_constant='skip'): + """Add a column of ones to a numpy array. + + Parameters + ---------- + data : array_like + A column-ordered design matrix. + prepend : bool, default True + If True the constant is in the first column, else appended. + has_constant : {'skip', 'add', 'raise'}, default 'skip' + Behavior when *data* already contains a constant column. + ``'skip'`` returns *data* unchanged, ``'raise'`` raises + ``ValueError``, ``'add'`` adds another column of ones anyway. + + Returns + ------- + ndarray + The array with a ones column prepended (or appended). + """ + x = np.asarray(data) + if isinstance(data, pd.DataFrame): + raise TypeError( + "add_constant does not support pandas DataFrames; " + "pass a numpy array instead" + ) + if x.ndim == 1: + x = x[:, None] + elif x.ndim > 2: + raise ValueError('Only implemented for 2-dimensional arrays') + + if has_constant != 'add': + is_const = (np.ptp(x, axis=0) == 0) & np.all(x != 0.0, axis=0) + if is_const.any(): + if has_constant == 'skip': + return x + cols = ",".join(str(c) for c in np.where(is_const)[0]) + raise ValueError(f"Column(s) {cols} are constant.") + + ones = np.ones(x.shape[0]) + parts = [ones, x] if prepend else [x, ones] + return np.column_stack(parts) + + class IdentityFeatures(TransformerMixin): """Featurizer that just returns the input data.""" diff --git a/econml/validate/drtester.py b/econml/validate/drtester.py index 85f08271f..e90025786 100644 --- a/econml/validate/drtester.py +++ b/econml/validate/drtester.py @@ -6,12 +6,11 @@ from sklearn.model_selection import check_cv from sklearn.model_selection import cross_val_predict, StratifiedKFold, KFold from econml._lazy import _LazyModule -from econml.utilities import deprecated +from econml.utilities import deprecated, add_constant from .results import CalibrationEvaluationResults, BLPEvaluationResults, UpliftEvaluationResults, EvaluationResults from .utils import calculate_dr_outcomes, calc_uplift _statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for evaluate_blp() -_statsmodels_tools = _LazyModule("statsmodels.tools") # lazy: only needed for evaluate_blp() class DRTester: @@ -479,7 +478,7 @@ def evaluate_blp( self.get_cate_preds(Xval, Xtrain) if self.n_treat == 1: # binary treatment - reg = _statsmodels_api.OLS(self.dr_val_, _statsmodels_tools.add_constant(self.cate_preds_val_)).fit() + reg = _statsmodels_api.OLS(self.dr_val_, add_constant(self.cate_preds_val_)).fit() params = [reg.params[1]] errs = [reg.bse[1]] pvals = [reg.pvalues[1]] @@ -490,7 +489,7 @@ def evaluate_blp( for k in range(self.n_treat): # run a separate regression for each reg = _statsmodels_api.OLS( self.dr_val_[:, k], - _statsmodels_tools.add_constant(self.cate_preds_val_[:, k]) + add_constant(self.cate_preds_val_[:, k]) ).fit(cov_type='HC1') params.append(reg.params[1]) errs.append(reg.bse[1]) From adbb8d8d30756cbee08139e0b101db88e7ca0f62 Mon Sep 17 00:00:00 2001 From: Keith Battocchi Date: Fri, 10 Apr 2026 13:18:23 -0400 Subject: [PATCH 3/3] Use _LazyModule to replace inline circular-import workarounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 4 deferred imports that existed to avoid circular imports with top-level _LazyModule declarations. The lazy proxy defers the actual importlib.import_module() call until first attribute access, which happens inside function/method bodies after all modules have finished loading — so the circular dependency is still broken, but the import declaration lives at the top of the file. - econml/dml/causal_forest.py: econml.score (RScorer) - econml/inference/_bootstrap.py: econml._cate_estimator (BaseCateEstimator) - econml/sklearn_extensions/linear_model.py: econml.sklearn_extensions.model_selection - econml/_ortho_learner.py: econml.dml._rlearner (_ModelFinal) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Keith Battocchi --- econml/_ortho_learner.py | 7 ++++--- econml/dml/causal_forest.py | 5 ++++- econml/inference/_bootstrap.py | 7 ++++--- econml/sklearn_extensions/linear_model.py | 7 +++---- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index cf966b631..e27b47a67 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -41,6 +41,9 @@ class in this module implements the general logic in a very versatile way filter_none_kwargs, one_hot_encoder, strata_from_discrete_arrays, jacify_featurizer, reshape, shape) from .sklearn_extensions.model_selection import ModelSelector +from ._lazy import _LazyModule + +_rlearner = _LazyModule("econml.dml._rlearner") # lazy: avoid circular import try: import ray @@ -1149,9 +1152,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None, s } # If using an _rlearner, the scoring parameter can be passed along, if provided if scoring is not None: - # Cannot import in header, or circular imports - from .dml._rlearner import _ModelFinal - if isinstance(self._ortho_learner_model_final, _ModelFinal): + if isinstance(self._ortho_learner_model_final, _rlearner._ModelFinal): score_kwargs['scoring'] = scoring else: raise NotImplementedError("scoring parameter only implemented for " diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py index e83ca4d4e..d6ff14eeb 100644 --- a/econml/dml/causal_forest.py +++ b/econml/dml/causal_forest.py @@ -17,9 +17,12 @@ from .._cate_estimator import LinearCateEstimator from .._shap import _shap_explain_multitask_model_cate from .._ortho_learner import _OrthoLearner +from .._lazy import _LazyModule from ..validate.sensitivity_analysis import (sensitivity_interval, RV, dml_sensitivity_values, sensitivity_summary) +_score = _LazyModule("econml.score") # lazy: avoid circular import + class _CausalForestFinalWrapper: @@ -757,7 +760,7 @@ def tune(self, Y, T, *, X=None, W=None, The tuned causal forest object. This is the same object (not a copy) as the original one, but where all parameters of the object have been set to the best performing parameters from the tuning grid. """ - from ..score import RScorer # import here to avoid circular import issue + RScorer = _score.RScorer Y, T, X, sample_weight, groups = check_input_arrays(Y, T, X, sample_weight, groups) W, = check_input_arrays(W, force_all_finite='allow-nan' if 'W' in self._gen_allowed_missing_vars() else True, ensure_2d=True) diff --git a/econml/inference/_bootstrap.py b/econml/inference/_bootstrap.py index 7993a9c74..36089bd63 100644 --- a/econml/inference/_bootstrap.py +++ b/econml/inference/_bootstrap.py @@ -6,6 +6,9 @@ from joblib import Parallel, delayed from sklearn.base import clone from scipy.stats import norm +from .._lazy import _LazyModule + +_cate_estimator = _LazyModule("econml._cate_estimator") # lazy: avoid circular import class BootstrapEstimator: @@ -83,10 +86,8 @@ def fit(self, *args, **named_args): The full signature of this method is the same as that of the wrapped object's `fit` method. """ - from .._cate_estimator import BaseCateEstimator # need to nest this here to avoid circular import - index_chunks = None - if isinstance(self._instances[0], BaseCateEstimator): + if isinstance(self._instances[0], _cate_estimator.BaseCateEstimator): index_chunks = self._instances[0]._strata(*args, **named_args) if index_chunks is not None: index_chunks = self.__stratified_indices(index_chunks) diff --git a/econml/sklearn_extensions/linear_model.py b/econml/sklearn_extensions/linear_model.py index 4d774742f..8c44ee31f 100644 --- a/econml/sklearn_extensions/linear_model.py +++ b/econml/sklearn_extensions/linear_model.py @@ -39,6 +39,7 @@ _statsmodels_api = _LazyModule("statsmodels.api") # lazy: only needed for RLM _statsmodels = _LazyModule("statsmodels") # lazy: only needed for RLM robust norms +_model_selection = _LazyModule("econml.sklearn_extensions.model_selection") # lazy: avoid circular import class _WeightedCVIterableWrapper(_CVIterableWrapper): @@ -57,15 +58,13 @@ def split(self, X=None, y=None, groups=None, sample_weight=None): def _weighted_check_cv(cv=5, y=None, classifier=False, random_state=None): - # local import to avoid circular imports - from .model_selection import WeightedKFold, WeightedStratifiedKFold cv = 5 if cv is None else cv if isinstance(cv, numbers.Integral): if (classifier and (y is not None) and (type_of_target(y) in ('binary', 'multiclass'))): - return WeightedStratifiedKFold(cv, random_state=random_state) + return _model_selection.WeightedStratifiedKFold(cv, random_state=random_state) else: - return WeightedKFold(cv, random_state=random_state) + return _model_selection.WeightedKFold(cv, random_state=random_state) if not hasattr(cv, 'split') or isinstance(cv, str): if not isinstance(cv, Iterable) or isinstance(cv, str):