diff --git a/causalml/inference/tree/uplift.pyx b/causalml/inference/tree/uplift.pyx index 1f60b908..272e2ae5 100644 --- a/causalml/inference/tree/uplift.pyx +++ b/causalml/inference/tree/uplift.pyx @@ -33,6 +33,7 @@ from joblib import Parallel, delayed from packaging import version from sklearn.model_selection import train_test_split from sklearn.utils import check_X_y, check_array, check_random_state +import numbers if version.parse(sklearn.__version__) >= version.parse('0.22.0'): from sklearn.utils._testing import ignore_warnings @@ -2264,7 +2265,7 @@ class UpliftTreeClassifier: else: v = observations[tree.col] branch = None - if isinstance(v, int) or isinstance(v, float): + if isinstance(v, numbers.Number): if v >= tree.value: branch = tree.trueBranch else: @@ -2311,7 +2312,7 @@ class UpliftTreeClassifier: return dict(result) else: branch = None - if isinstance(v, int) or isinstance(v, float): + if isinstance(v, numbers.Number): if v >= tree.value: branch = tree.trueBranch else: