diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst index 4f0d26a9dfbfb..d67949d05010c 100644 --- a/doc/modules/tree.rst +++ b/doc/modules/tree.rst @@ -634,6 +634,20 @@ Mean Absolute Error: Note that it is 3–6× slower to fit than the MSE criterion as of version 1.8. +Quantile (pinball loss): + +.. math:: + + q_{\alpha}(y)_m = \underset{y \in Q_m}{\mathrm{quantile}_{\alpha}}(y) + + H(Q_m) = \frac{1}{n_m} \sum_{y \in Q_m} + \left(\alpha \max(y - q_{\alpha}(y)_m, 0) + + (1-\alpha) \max(q_{\alpha}(y)_m - y, 0)\right) + +Use ``criterion="quantile"`` together with the ``quantile`` parameter to +choose :math:`\alpha \in (0, 1)`. The special case ``quantile=0.5`` corresponds +to the median. + .. _tree_missing_value_support: Missing Values Support diff --git a/doc/whats_new/upcoming_changes/sklearn.tree/32903.feature.rst b/doc/whats_new/upcoming_changes/sklearn.tree/32903.feature.rst new file mode 100644 index 0000000000000..e2107b00154f2 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.tree/32903.feature.rst @@ -0,0 +1,6 @@ +- :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeRegressor`, + :class:`ensemble.RandomForestRegressor`, and :class:`ensemble.ExtraTreesRegressor` + now support `criterion="quantile"` together with the `quantile` parameter to + optimize the pinball loss (also known as the quantile loss). This effectively + enables quantile regression. + By :user:`Arthur Lacote ` diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 6df5152e04273..debfe5d9d7489 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -345,7 +345,10 @@ def fit(self, X, y, sample_weight=None): # will raise an error if the underlying tree base estimator can't handle missing # values. Only the criterion is required to determine if the tree supports # missing values. - estimator = type(self.estimator)(criterion=self.criterion) + estimator_kwargs = {"criterion": self.criterion} + if self.criterion == "quantile": + estimator_kwargs["quantile"] = self.quantile + estimator = type(self.estimator)(**estimator_kwargs) missing_values_in_feature_mask = ( estimator._compute_missing_values_in_feature_mask( X, estimator_name=self.__class__.__name__ @@ -697,7 +700,10 @@ def __sklearn_tags__(self): tags = super().__sklearn_tags__() # Only the criterion is required to determine if the tree supports # missing values - estimator = type(self.estimator)(criterion=self.criterion) + estimator_kwargs = {"criterion": self.criterion} + if self.criterion == "quantile": + estimator_kwargs["quantile"] = self.quantile + estimator = type(self.estimator)(**estimator_kwargs) tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan return tags @@ -1609,14 +1615,16 @@ class RandomForestRegressor(ForestRegressor): The default value of ``n_estimators`` changed from 10 to 100 in 0.22. - criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error" + criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \ + default="squared_error" The function to measure the quality of a split. Supported criteria are "squared_error" for the mean squared error, which is equal to variance reduction as feature selection criterion and minimizes the L2 loss using the mean of each terminal node, "absolute_error" for the mean absolute error, which minimizes the L1 loss using the median of each terminal - node, and "poisson" which uses reduction in Poisson deviance to find splits, - also using the mean of each terminal node.
+ node, "quantile" which minimizes the pinball loss using the quantile of each + terminal node (controlled by ``quantile``), and "poisson" which uses reduction + in Poisson deviance to find splits, also using the mean of each terminal node. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -1627,6 +1635,9 @@ class RandomForestRegressor(ForestRegressor): .. versionchanged:: 1.9 Criterion `"friedman_mse"` was deprecated. + .. versionadded:: 1.9 + Quantile/Pinball loss criterion + max_depth : int, default=None The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than @@ -1786,6 +1797,10 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 1.4 + quantile : float, default=0.5 + The quantile to predict when ``criterion="quantile"``. It must be strictly + between 0 and 1. + Attributes ---------- estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor` @@ -1913,6 +1928,7 @@ def __init__( ccp_alpha=0.0, max_samples=None, monotonic_cst=None, + quantile=0.5, ): super().__init__( estimator=DecisionTreeRegressor(), @@ -1929,6 +1945,7 @@ def __init__( "random_state", "ccp_alpha", "monotonic_cst", + "quantile", ), bootstrap=bootstrap, oob_score=oob_score, @@ -1959,6 +1976,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.quantile = quantile class ExtraTreesClassifier(ForestClassifier): @@ -2378,14 +2396,16 @@ class ExtraTreesRegressor(ForestRegressor): The default value of ``n_estimators`` changed from 10 to 100 in 0.22. - criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error" + criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \ + default="squared_error" The function to measure the quality of a split. Supported criteria are "squared_error" for the mean squared error, which is equal to variance reduction as feature selection criterion and minimizes the L2 loss using the mean of each terminal node, "absolute_error" for the mean absolute error, which minimizes the L1 loss using the median of each terminal - node, and "poisson" which uses reduction in Poisson deviance to find splits, - also using the mean of each terminal node. + node, "quantile" which minimizes the pinball loss using the quantile of each + terminal node (controlled by ``quantile``), and "poisson" which uses reduction + in Poisson deviance to find splits, also using the mean of each terminal node. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -2393,6 +2413,9 @@ class ExtraTreesRegressor(ForestRegressor): .. versionchanged:: 1.9 Criterion `"friedman_mse"` was deprecated. + .. versionadded:: 1.9 + Quantile/Pinball loss criterion + max_depth : int, default=None The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than @@ -2556,6 +2579,10 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 1.4 + quantile : float, default=0.5 + The quantile to predict when ``criterion="quantile"``. It must be strictly + between 0 and 1. 
+ Attributes ---------- estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` @@ -2667,6 +2694,7 @@ def __init__( ccp_alpha=0.0, max_samples=None, monotonic_cst=None, + quantile=0.5, ): super().__init__( estimator=ExtraTreeRegressor(), @@ -2683,6 +2711,7 @@ def __init__( "random_state", "ccp_alpha", "monotonic_cst", + "quantile", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2713,6 +2742,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.quantile = quantile class RandomTreesEmbedding(TransformerMixin, BaseForest): diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 9ec8030899d18..d0ff61ed577ad 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -378,6 +378,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta): } _parameter_constraints.pop("splitter") _parameter_constraints.pop("monotonic_cst") + _parameter_constraints.pop("quantile") @abstractmethod def __init__( diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 7d6283300a256..6887e2809ade6 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -294,12 +294,12 @@ def test_probability(name): "name, criterion", itertools.chain( product(FOREST_CLASSIFIERS, ["gini", "log_loss"]), - product(FOREST_REGRESSORS, ["squared_error", "absolute_error"]), + product(FOREST_REGRESSORS, ["squared_error", "absolute_error", "quantile"]), ), ) def test_importances(dtype, name, criterion): tolerance = 0.01 - if name in FOREST_REGRESSORS and criterion == "absolute_error": + if name in FOREST_REGRESSORS and criterion in {"absolute_error", "quantile"}: tolerance = 0.05 # cast as dtype diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index dc83aa7d3daea..a69f62ed17c60 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -76,6 +76,7 @@ "squared_error": _criterion.MSE, "absolute_error": _criterion.MAE, "poisson": _criterion.Poisson, + "quantile": _criterion.Pinball, } DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter} @@ -382,7 +383,14 @@ def _fit( self.n_outputs_, self.n_classes_ ) else: - criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples) + args = (self.n_outputs_, n_samples) + if self.criterion == "quantile": + args = (*args, self.quantile) + if self.criterion == "absolute_error": + # FIXME: this is coupled with code at a much lower level + # because of the inheritance behavior of __cinit__ + args = (*args, 0.5) + criterion = CRITERIA_REG[self.criterion](*args) else: # Make a deepcopy in case the criterion has mutable attributes that # might be shared and modified concurrently during parallel fitting @@ -1117,14 +1125,16 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): Parameters ---------- - criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error" + criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \ + default="squared_error" The function to measure the quality of a split. 
Supported criteria are "squared_error" for the mean squared error, which is equal to variance reduction as feature selection criterion and minimizes the L2 loss using the mean of each terminal node, "absolute_error" for the mean absolute error, which minimizes the L1 loss using the median of each terminal - node, and "poisson" which uses reduction in Poisson deviance to find splits, - also using the mean of each terminal node. + node, "quantile" which minimizes the pinball loss using the quantile of each + terminal node (controlled by ``quantile``), and "poisson" which uses reduction + in Poisson deviance to find splits, also using the mean of each terminal node. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -1135,6 +1145,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionchanged:: 1.9 Criterion `"friedman_mse"` was deprecated. + .. versionadded:: 1.9 + Quantile/Pinball loss criterion + splitter : {"best", "random"}, default="best" The strategy used to choose the split at each node. Supported strategies are "best" to choose the best split and "random" to choose @@ -1255,6 +1268,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 1.4 + quantile : float, default=0.5 + The quantile to predict when ``criterion="quantile"``. It must be strictly + between 0 and 1. If 0.5 (default), the model predicts the median. + Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1338,9 +1355,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): _parameter_constraints: dict = { **BaseDecisionTree._parameter_constraints, "criterion": [ - StrOptions({"squared_error", "absolute_error", "poisson"}), + StrOptions({"squared_error", "absolute_error", "poisson", "quantile"}), Hidden(Criterion), ], + "quantile": [Interval(RealNotInt, 0.0, 1.0, closed="neither")], } def __init__( @@ -1358,6 +1376,7 @@ def __init__( min_impurity_decrease=0.0, ccp_alpha=0.0, monotonic_cst=None, + quantile=0.5, ): if isinstance(criterion, str) and criterion == "friedman_mse": # TODO(1.11): remove support of "friedman_mse" criterion. @@ -1383,6 +1402,7 @@ def __init__( ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, ) + self.quantile = quantile @_fit_context(prefer_skip_nested_validation=True) def fit(self, X, y, sample_weight=None, check_input=True): @@ -1767,14 +1787,16 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Parameters ---------- - criterion : {"squared_error", "absolute_error", "poisson"}, default="squared_error" + criterion : {"squared_error", "absolute_error", "quantile", "poisson"}, \ + default="squared_error" The function to measure the quality of a split. Supported criteria are "squared_error" for the mean squared error, which is equal to variance reduction as feature selection criterion and minimizes the L2 loss using the mean of each terminal node, "absolute_error" for the mean absolute error, which minimizes the L1 loss using the median of each terminal - node, and "poisson" which uses reduction in Poisson deviance to find splits, - also using the mean of each terminal node. + node, "quantile" which minimizes the pinball loss using the quantile of each + terminal node (controlled by ``quantile``), and "poisson" which uses reduction + in Poisson deviance to find splits, also using the mean of each terminal node. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -1785,6 +1807,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. 
versionchanged:: 1.9 Criterion `"friedman_mse"` was deprecated. + .. versionadded:: 1.9 + Quantile/Pinball loss criterion + splitter : {"random", "best"}, default="random" The strategy used to choose the split at each node. Supported strategies are "best" to choose the best split and "random" to choose @@ -1897,6 +1922,10 @@ class ExtraTreeRegressor(DecisionTreeRegressor): .. versionadded:: 1.4 + quantile : float, default=0.5 + The quantile to predict when ``criterion="quantile"``. It must be strictly + between 0 and 1. If 0.5 (default), the model predicts the median. + Attributes ---------- max_features_ : int @@ -1981,6 +2010,7 @@ def __init__( max_leaf_nodes=None, ccp_alpha=0.0, monotonic_cst=None, + quantile=0.5, ): super().__init__( criterion=criterion, @@ -1995,6 +2025,7 @@ def __init__( random_state=random_state, ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, + quantile=quantile, ) def __sklearn_tags__(self): diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 19c0d9b03c743..bc39c29a354fd 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1175,9 +1175,9 @@ cdef class MSE(RegressionCriterion): impurity_right[0] /= self.n_outputs -# Helper for MAE criterion: +# Helper for Pinball criterion: -cdef void precompute_absolute_errors( +cdef void precompute_pinball_losses( const float64_t[::1] sorted_y, const intp_t[::1] ranks, const float64_t[:] sample_weight, @@ -1185,19 +1185,20 @@ WeightedFenwickTree tree, intp_t start, intp_t end, - float64_t[::1] abs_errors, - float64_t[::1] medians, + float64_t alpha, + float64_t[::1] pinball_losses, + float64_t[::1] quantiles, ) noexcept nogil: """ - Fill `abs_errors` and `medians`. + Fill `pinball_losses` and `quantiles`. If start < end: - Forward pass: Computes the "prefix" AEs/medians - i.e the AEs for each set of indices sample_indices[start:start + i] + Forward pass: Computes the "prefix" losses/quantiles + i.e. the losses for each set of indices sample_indices[start:start + i] with i in {1, ..., n}, where n = end - start. Else: - Backward pass: Computes the "suffix" AEs/medians - i.e the AEs for each set of indices sample_indices[start - i:start] + Backward pass: Computes the "suffix" losses/quantiles + i.e. the losses for each set of indices sample_indices[start - i:start] with i in {1, ..., n}, where n = start - end. Parameters @@ -1217,18 +1218,20 @@ Start index in `sample_indices` end : intp_t End index (exclusive) in `sample_indices` - abs_errors : float64_t[::1] - array to store (increment) the computed absolute errors. Shape: (n,) with n := end - start - medians : float64_t[::1] - array to store (overwrite) the computed medians. Shape: (n,) + alpha : float64_t + Quantile level for the pinball loss (between 0 and 1) + pinball_losses : float64_t[::1] + array to store (increment) the computed pinball losses. Shape: (n,) with n := end - start + quantiles : float64_t[::1] + array to store (overwrite) the computed quantiles. Shape: (n,) Complexity: O(n log n) """ cdef: intp_t p, i, step, n, rank, quantile_rank, quantile_prev_rank float64_t w = 1.
- float64_t half_weight, median + float64_t target_weight, quantile float64_t w_right, w_left, wy_left, wy_right if start < end: @@ -1251,37 +1254,38 @@ cdef void precompute_absolute_errors( rank = ranks[p] tree.add(rank, sorted_y[rank], w) - # Weighted median by cumulative weight: the median is where the - # cumulative weight crosses half of the total weight. - half_weight = 0.5 * tree.total_w - # find the smallest activated rank with cumulative weight > half_weight + # Weighted quantile by cumulative weight: the quantile is where the + # cumulative weight crosses alpha times the total weight. + target_weight = alpha * tree.total_w + # find the smallest activated rank with cumulative weight > target_weight # while returning the prefix sums (`w_left` and `wy_left`) # up to (and excluding) that index: - median_rank = tree.search(half_weight, &w_left, &wy_left, &median_prev_rank) - - if median_rank != median_prev_rank: - # Exact match for half_weight fell between two consecutive ranks: - # cumulative weight up to `median_rank` excluded is exactly half_weight. - # In that case, `median_prev_rank` is the activated rank such that - # the cumulative weight up to it included is exactly half_weight. - # In this case we take the mid-point: - median = (sorted_y[median_prev_rank] + sorted_y[median_rank]) / 2 + quantile_rank = tree.search(target_weight, &w_left, &wy_left, &quantile_prev_rank) + + if quantile_rank != quantile_prev_rank: + # Exact match for target_weight fell between two consecutive ranks: + # cumulative weight up to `quantile_rank` excluded is exactly target_weight. + # In that case, `quantile_prev_rank` is the activated rank such that + # the cumulative weight up to it included is exactly target_weight. + # In this case we take the mid-point to match + # sklearn.utils.stats._weighted_percentile(..., average=True) + quantile = (sorted_y[quantile_prev_rank] + sorted_y[quantile_rank]) / 2 else: - # if there are no exact match for half_weight in the cumulative weights - # `median_rank == median_prev_rank` and the median is: - median = sorted_y[median_rank] + # if there is no exact match for target_weight in the cumulative weights + # `quantile_rank == quantile_prev_rank` and the quantile is: + quantile = sorted_y[quantile_rank] # Convert left prefix sums into right-hand complements.
w_right = tree.total_w - w_left wy_right = tree.total_wy - wy_left - medians[p] = median - # Pinball-loss identity for absolute error at the current set: - # sum_{y_i >= m} w_i (y_i - m) = wy_right - m * w_right - # sum_{y_i < m} w_i (m - y_i) = m * w_left - wy_left - abs_errors[p] += ( - (wy_right - median * w_right) - + (median * w_left - wy_left) + quantiles[p] = quantile + # Pinball loss identity for the alpha-quantile at the current set: + # sum_{y_i >= q} w_i * alpha * (y_i - q) = alpha * (wy_right - q * w_right) + # sum_{y_i < q} w_i * (1-alpha) * (q - y_i) = (1-alpha) * (q * w_left - wy_left) + pinball_losses[p] += ( + alpha * (wy_right - quantile * w_right) + + (1.0 - alpha) * (quantile * w_left - wy_left) ) p += step @@ -1301,15 +1305,16 @@ cdef inline void compute_ranks( ranks[sorted_indices[i]] = i -def _py_precompute_absolute_errors( +def _py_precompute_pinball_losses( const float64_t[:, ::1] ys, const float64_t[:] sample_weight, const intp_t[:] sample_indices, const intp_t start, const intp_t end, const intp_t n, + const float64_t alpha, ): - """Used for testing precompute_absolute_errors.""" + """Used for testing precompute_pinball_losses.""" cdef: intp_t p, i intp_t s = start @@ -1318,8 +1323,8 @@ def _py_precompute_absolute_errors( float64_t[::1] sorted_y = np.empty(n, dtype=np.float64) intp_t[::1] sorted_indices = np.empty(n, dtype=np.intp) intp_t[::1] ranks = np.empty(n, dtype=np.intp) - float64_t[::1] abs_errors = np.zeros(n, dtype=np.float64) - float64_t[::1] medians = np.empty(n, dtype=np.float64) + float64_t[::1] pinball_losses = np.zeros(n, dtype=np.float64) + float64_t[::1] quantiles = np.empty(n, dtype=np.float64) if start > end: s = end + 1 @@ -1329,28 +1334,30 @@ def _py_precompute_absolute_errors( sorted_y[p - s] = ys[i, 0] compute_ranks(&sorted_y[0], &sorted_indices[0], &ranks[s], n) - precompute_absolute_errors( + precompute_pinball_losses( sorted_y, ranks, sample_weight, sample_indices, tree, - start, end, abs_errors, medians + start, end, alpha, pinball_losses, quantiles ) - return np.asarray(abs_errors)[s:e], np.asarray(medians)[s:e] + return np.asarray(pinball_losses)[s:e], np.asarray(quantiles)[s:e] -cdef class MAE(Criterion): - r"""Mean absolute error impurity criterion. +cdef class Pinball(Criterion): + r"""Pinball loss impurity criterion. - It has almost nothing in common with other regression criterions - so it doesn't inherit from RegressionCriterion. + This criterion generalizes the Mean Absolute Error (MAE) by using + quantile regression with a specified quantile level (alpha). + MAE corresponds to alpha=0.5 (median). - MAE = (1 / n)*(\sum_i |y_i - p_i|), where y_i is the true - value and p_i is the predicted value. - In a decision tree, that prediction is the (weighted) median + Pinball loss = (1 / n)*(\sum_i rho_alpha(y_i - q_i)), where y_i is the true + value, q_i is the predicted quantile, and rho_alpha is the pinball loss function: + rho_alpha(u) = u * (alpha - I(u < 0)) + In a decision tree, that prediction is the (weighted) alpha-quantile of the targets in the node. How this implementation works ----------------------------- This class precomputes in `reset`, for the current node, - the absolute-error values and corresponding medians for all + the pinball loss values and corresponding quantiles for all potential split positions: every p in [start, end). 
For that: @@ -1360,17 +1367,17 @@ cdef class MAE(Criterion): * "activate" one sample at a time at its rank within a prefix sum tree, the `WeightedFenwickTree`: `tree.add(rank, y, weight)` The tree maintains cumulative sums of weights and of `weight * y` - * search for the half total weight in the tree: - `tree.search(current_total_weight / 2)`. + * search for the target weight alpha * total_weight in the tree: + `tree.search(alpha * current_total_weight)`. This allows us to retrieve/compute: - * the current weighted median value - * the absolute-error contribution via the standard pinball-loss identity: - AE = (wy_right - median * w_right) + (median * w_left - wy_left) + * the current weighted quantile value + * the pinball loss contribution via the standard pinball-loss identity: + PL = alpha * (wy_right - quantile * w_right) + (1-alpha) * (quantile * w_left - wy_left) - We perform two such passes: - * one forward from `start` to `end - 1` to fill `left_abs_errors[p]` and - `left_medians[p]` for left children. + * one forward from `start` to `end - 1` to fill `left_pinball_losses[p]` and + `left_quantiles[p]` for left children. * one backward from `end - 1` down to `start` to fill - `right_abs_errors[p]` and `right_medians[p]` for right children. + `right_pinball_losses[p]` and `right_quantiles[p]` for right children. Complexity: time complexity is O(n log n), indeed: - computing ranks is based on sorting: O(n log n) @@ -1380,38 +1387,39 @@ cdef class MAE(Criterion): How the other methods use the precomputations -------------------------------------------- - `reset` performs the precomputation described above. - It also stores the node weighted median per output in - `node_medians` (prediction value of the node). + It also stores the node weighted quantile per output in + `node_quantiles` (prediction value of the node). - `update(new_pos)` only updates `weighted_n_left` and `weighted_n_right`; - no recomputation of errors is needed. + no recomputation of losses is needed. - - `children_impurity` reads the precomputed absolute errors at - `left_abs_errors[pos - 1]` and `right_abs_errors[pos]` and scales + - `children_impurity` reads the precomputed pinball losses at + `left_pinball_losses[pos - 1]` and `right_pinball_losses[pos]` and scales them by the corresponding child weights and `n_outputs` to report the impurity of each child. - `middle_value` and `check_monotonicity` use the precomputed - `left_medians[pos - 1]` and `right_medians[pos]` to derive the + `left_quantiles[pos - 1]` and `right_quantiles[pos]` to derive the mid-point value and to validate monotonic constraints when enabled. - - Missing values are not supported for MAE: `init_missing` raises. + - Missing values are not supported for Pinball: `init_missing` raises. 
For a complementary, in-depth discussion of the mathematics and design choices, see the external report: https://github.com/cakedev0/fast-mae-split/blob/main/report.ipynb """ - cdef float64_t[::1] node_medians - cdef float64_t[::1] left_abs_errors - cdef float64_t[::1] right_abs_errors - cdef float64_t[::1] left_medians - cdef float64_t[::1] right_medians + cdef float64_t alpha + cdef float64_t[::1] node_quantiles + cdef float64_t[::1] left_pinball_losses + cdef float64_t[::1] right_pinball_losses + cdef float64_t[::1] left_quantiles + cdef float64_t[::1] right_quantiles cdef float64_t[::1] sorted_y cdef intp_t [::1] sorted_indices cdef intp_t[::1] ranks cdef WeightedFenwickTree prefix_sum_tree - def __cinit__(self, intp_t n_outputs, intp_t n_samples): + def __cinit__(self, intp_t n_outputs, intp_t n_samples, float64_t alpha=0.5): """Initialize parameters for this criterion. Parameters @@ -1422,6 +1430,8 @@ cdef class MAE(Criterion): n_samples : intp_t The total number of samples to fit on """ + self.alpha = alpha + # Default values self.start = 0 self.pos = 0 @@ -1434,14 +1444,14 @@ cdef class MAE(Criterion): self.weighted_n_left = 0.0 self.weighted_n_right = 0.0 - self.node_medians = np.zeros(n_outputs, dtype=np.float64) + self.node_quantiles = np.zeros(n_outputs, dtype=np.float64) # Note: this criterion has a n_samples x 64 bytes memory footprint, which is # fine as it's instantiated only once to build an entire tree - self.left_abs_errors = np.empty(n_samples, dtype=np.float64) - self.right_abs_errors = np.empty(n_samples, dtype=np.float64) - self.left_medians = np.empty(n_samples, dtype=np.float64) - self.right_medians = np.empty(n_samples, dtype=np.float64) + self.left_pinball_losses = np.empty(n_samples, dtype=np.float64) + self.right_pinball_losses = np.empty(n_samples, dtype=np.float64) + self.left_quantiles = np.empty(n_samples, dtype=np.float64) + self.right_quantiles = np.empty(n_samples, dtype=np.float64) self.ranks = np.empty(n_samples, dtype=np.intp) # Important: The arrays declared above are indexed with # the absolute position `p` in `sample_indices` (not with a 0-based offset). @@ -1474,13 +1484,12 @@ cdef class MAE(Criterion): sample_indices[start:start] and sample_indices[start:end]. WARNING: sample_indices will be modified in-place externally after this method is called. """ cdef: intp_t i, p intp_t n = end - start float64_t w = 1.0 - # Initialize fields self.y = y self.sample_weight = sample_weight @@ -1506,7 +1515,7 @@ cdef class MAE(Criterion): if n_missing == 0: return with gil: - raise ValueError("missing values is not supported for MAE.") + raise ValueError("Missing values are not supported for the Pinball criterion.") cdef int reset(self) except -1 nogil: """Reset the criterion at pos=start. Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. Reset might be called after an external class has changed inplace self.sample_indices[start:end], hence re-computing - the absolute errors is needed. + the pinball loss is needed.
""" cdef intp_t k, p, i @@ -1525,15 +1534,23 @@ cdef class MAE(Criterion): self.pos = self.start n_bytes = self.n_node_samples * sizeof(float64_t) - memset(&self.left_abs_errors[self.start], 0, n_bytes) - memset(&self.right_abs_errors[self.start], 0, n_bytes) + memset(&self.left_pinball_losses[self.start], 0, n_bytes) + memset(&self.right_pinball_losses[self.start], 0, n_bytes) # Multi-output handling: - # absolute errors are accumulated across outputs by - # incrementing `left_abs_errors` and `right_abs_errors` on each pass. - # The per-output medians arrays are overwritten at each output iteration + # pinball losses are accumulated across outputs by + # incrementing `left_pinball_losses` and `right_pinball_losses` on each pass. + # The per-output quantiles arrays are overwritten at each output iteration # as they are only used for monotonicity checks when `n_outputs == 1`. + # Precompute pinball losses (summed over each output) + # and quantiles (used only when n_outputs=1) + # of the right and left child of all possible splits + # for the current ordering of `sample_indices` + # Precomputation is needed here and can't be done step-by-step in the update method + # like for other criterions. Indeed, we don't have efficient ways to update right child + # statistics when removing samples from it. So we compute right child losses/quantiles by + # traversing from right to left (and hence only adding samples). for k in range(self.n_outputs): # 1) Node-local ordering: @@ -1554,28 +1571,28 @@ cdef class MAE(Criterion): ) # 2) Forward pass - # from `start` to `end - 1` to fill `left_abs_errors[p]` and - # `left_medians[p]` for left children. - precompute_absolute_errors( + # from `start` to `end - 1` to fill `left_pinball_losses[p]` and + # `left_quantiles[p]` for left children. + precompute_pinball_losses( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, - self.prefix_sum_tree, self.start, self.end, - # left_abs_errors is incremented, left_medians is overwritten - self.left_abs_errors, self.left_medians + self.prefix_sum_tree, self.start, self.end, self.alpha, + # left_pinball_losses is incremented, left_quantiles is overwritten + self.left_pinball_losses, self.left_quantiles ) # 3) Backward pass - # from `end - 1` down to `start` to fill `right_abs_errors[p]` - # and `right_medians[p]` for right children. - precompute_absolute_errors( + # from `end - 1` down to `start` to fill `right_pinball_losses[p]` + # and `right_quantiles[p]` for right children. + precompute_pinball_losses( self.sorted_y, self.ranks, self.sample_weight, self.sample_indices, - self.prefix_sum_tree, self.end - 1, self.start - 1, - # right_abs_errors is incremented, right_medians is overwritten - self.right_abs_errors, self.right_medians + self.prefix_sum_tree, self.end - 1, self.start - 1, self.alpha, + # right_pinball_losses is incremented, right_quantiles is overwritten + self.right_pinball_losses, self.right_quantiles ) - # Store the median for the current node: when p == self.start all the + # Store the quantile for the current node: when p == self.start all the # node's data points are sent to the right child, so the current node - # median value and the right child median value would be equal. - self.node_medians[k] = self.right_medians[self.start] + # quantile value and the right child quantile value would be equal. 
+ self.node_quantiles[k] = self.right_quantiles[self.start] return 0 @@ -1612,7 +1629,7 @@ """Computes the node value of sample_indices[start:end] into dest.""" cdef intp_t k for k in range(self.n_outputs): - dest[k] = self.node_medians[k] + dest[k] = self.node_quantiles[k] cdef inline float64_t middle_value(self) noexcept nogil: """Compute the middle value of a split for monotonicity constraints as the simple average @@ -1622,8 +1639,8 @@ n_outputs == 1. """ return ( - self.left_medians[self.pos - 1] - + self.right_medians[self.pos] + self.left_quantiles[self.pos - 1] + + self.right_quantiles[self.pos] ) / 2 cdef inline bint check_monotonicity( @@ -1635,19 +1652,19 @@ """Check monotonicity constraint is satisfied at the current regression split""" return self._check_monotonicity( monotonic_cst, lower_bound, upper_bound, - self.left_medians[self.pos - 1], self.right_medians[self.pos]) + self.left_quantiles[self.pos - 1], self.right_quantiles[self.pos]) cdef float64_t node_impurity(self) noexcept nogil: """Evaluate the impurity of the current node. - Evaluate the MAE criterion as impurity of the current node, + Evaluate the Pinball criterion as impurity of the current node, i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. Time complexity: O(1) (precomputed in `.reset()`) """ return ( - self.right_abs_errors[0] + self.right_pinball_losses[self.start] / (self.weighted_n_node_samples * self.n_outputs) ) @@ -1665,20 +1682,20 @@ # if pos == start, left child is empty, hence impurity is 0 if self.pos > self.start: - impurity_left += self.left_abs_errors[self.pos - 1] + impurity_left += self.left_pinball_losses[self.pos - 1] p_impurity_left[0] = impurity_left / (self.weighted_n_left * self.n_outputs) # if pos == end, right child is empty, hence impurity is 0 if self.pos < self.end: - impurity_right += self.right_abs_errors[self.pos] + impurity_right += self.right_pinball_losses[self.pos] p_impurity_right[0] = impurity_right / (self.weighted_n_right * self.n_outputs) - # those 2 methods are copied from the RegressionCriterion abstract class: def __reduce__(self): - return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) + return (type(self), (self.n_outputs, self.n_samples, self.alpha), self.__getstate__()) + # this method is copied from the RegressionCriterion abstract class: cdef inline void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil: """Clip the value in dest between lower_bound and upper_bound for monotonic constraints.""" if dest[0] < lower_bound: @@ -1687,6 +1704,27 @@ dest[0] = upper_bound +cdef class MAE(Pinball): + """ + The median is just the quantile with alpha=0.5, and the absolute error + is twice the pinball loss (with alpha=0.5). + """ + + # XXX: Trust the instantiator to pass alpha=0.5 to the __cinit__... + + cdef float64_t node_impurity(self) noexcept nogil: + return 2 * Pinball.node_impurity(self) + + cdef void children_impurity(self, float64_t* p_impurity_left, + float64_t* p_impurity_right) noexcept nogil: + Pinball.children_impurity(self, p_impurity_left, p_impurity_right) + p_impurity_left[0] *= 2 + p_impurity_right[0] *= 2 + + def __reduce__(self): + return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) + + cdef class Poisson(RegressionCriterion): """Half Poisson deviance as impurity criterion.
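[Reviewer note: the weighted-quantile search and the pinball-loss identity that `precompute_pinball_losses` implements with a Fenwick tree can be cross-checked with a plain NumPy sketch. This is illustrative only and not part of the patch; the helper name `weighted_quantile_and_pinball` is hypothetical, and plain cumulative sums stand in for the tree's prefix sums.]

import numpy as np

def weighted_quantile_and_pinball(y, w, alpha):
    """Weighted alpha-quantile of y and the total weighted pinball loss at it."""
    order = np.argsort(y)
    ys, ws = y[order], w[order]
    cw = np.cumsum(ws)        # prefix sums of weights (the Fenwick tree's role)
    cwy = np.cumsum(ws * ys)  # prefix sums of weight * y
    target = alpha * cw[-1]
    # smallest rank whose cumulative weight exceeds the target
    k = int(np.searchsorted(cw, target, side="right"))
    if k > 0 and np.isclose(cw[k - 1], target):
        # exact crossing between two ranks: take the mid-point, matching
        # sklearn.utils.stats._weighted_percentile(..., average=True)
        q = 0.5 * (ys[k - 1] + ys[k])
    else:
        q = ys[k]
    w_left = cw[k - 1] if k > 0 else 0.0
    wy_left = cwy[k - 1] if k > 0 else 0.0
    w_right, wy_right = cw[-1] - w_left, cwy[-1] - wy_left
    # the pinball-loss identity used by the precomputation:
    # sum_{y_i >= q} w_i * alpha * (y_i - q) = alpha * (wy_right - q * w_right)
    # sum_{y_i < q}  w_i * (1-alpha) * (q - y_i) = (1-alpha) * (q * w_left - wy_left)
    loss = alpha * (wy_right - q * w_right) + (1 - alpha) * (q * w_left - wy_left)
    return q, loss

[For unit weights and alpha=0.5 on y = [1, 2, 3, 4] this returns q = 2.5 and loss = 2.0, i.e. half the weighted absolute error — consistent with the MAE subclass above, which doubles the Pinball impurities.]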
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index beca79e3c18f8..0fa6ce069fcf0 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -23,6 +23,7 @@ from sklearn.metrics import ( accuracy_score, mean_absolute_error, + mean_pinball_loss, mean_poisson_deviance, mean_squared_error, ) @@ -41,7 +42,7 @@ DENSE_SPLITTERS, SPARSE_SPLITTERS, ) -from sklearn.tree._criterion import _py_precompute_absolute_errors +from sklearn.tree._criterion import _py_precompute_pinball_losses from sklearn.tree._partitioner import _py_sort from sklearn.tree._tree import ( NODE_DTYPE, @@ -1858,13 +1859,15 @@ def _pickle_copy(obj): assert n_outputs == n_outputs_ assert_array_equal(n_classes, n_classes_) - for _, typename in CRITERIA_REG.items(): - criteria = typename(n_outputs, n_samples) + for name, typename in CRITERIA_REG.items(): + args = (n_outputs, n_samples) + if name == "quantile": + args = (*args, 0.5) + criteria = typename(*args) result = copy_func(criteria).__reduce__() - typename_, (n_outputs_, n_samples_), _ = result + typename_, args_, _ = result assert typename == typename_ - assert n_outputs == n_outputs_ - assert n_samples == n_samples_ + assert args == args_ @pytest.mark.parametrize("sparse_container", [None] + CSC_CONTAINERS) @@ -2929,7 +2932,8 @@ def test_sort_log2_build(): assert_array_equal(samples, expected_samples) -def test_absolute_errors_precomputation_function(global_random_seed): +@pytest.mark.parametrize("alpha", [0.5, 0.1, 0.75]) +def test_pinball_loss_precomputation_function(alpha, global_random_seed): """ Test the main bit of logic of the MAE(RegressionCriterion) class (used by DecisionTreeRegressor(criterion="absolute_error")). @@ -2940,33 +2944,41 @@ def test_absolute_errors_precomputation_function(global_random_seed): it can be safely removed. 
""" - def compute_prefix_abs_errors_naive(y, w): + def compute_prefix_losses_naive(y, w): + """ + Computes the pinball loss for all (y[:i], w[:i]) + Naive: O(n^2 log n) + """ y = y.ravel().copy() - medians = [ - _weighted_percentile(y[:i], w[:i], 50, average=True) + quantiles = [ + _weighted_percentile(y[:i], w[:i], alpha * 100, average=True) for i in range(1, y.size + 1) ] - errors = [ - (np.abs(y[:i] - m) * w[:i]).sum() - for i, m in zip(range(1, y.size + 1), medians) + losses = [ + mean_pinball_loss( + y[:i], np.full(i, quantile), sample_weight=w[:i], alpha=alpha + ) + * w[:i].sum() + for i, quantile in zip(range(1, y.size + 1), quantiles) ] - return np.array(errors), np.array(medians) + return np.array(losses), np.array(quantiles) def assert_same_results(y, w, indices, reverse=False): - n = y.shape[0] args = (n - 1, -1) if reverse else (0, n) - abs_errors, medians = _py_precompute_absolute_errors(y, w, indices, *args, n) + losses, quantiles = _py_precompute_pinball_losses( + y, w, indices, *args, n, alpha=alpha + ) y_sorted = y[indices] w_sorted = w[indices] if reverse: y_sorted = y_sorted[::-1] w_sorted = w_sorted[::-1] - abs_errors_, medians_ = compute_prefix_abs_errors_naive(y_sorted, w_sorted) + losses_, quantiles_ = compute_prefix_losses_naive(y_sorted, w_sorted) if reverse: - abs_errors_ = abs_errors_[::-1] - medians_ = medians_[::-1] - assert_allclose(abs_errors, abs_errors_, atol=1e-11) - assert_allclose(medians, medians_, atol=1e-11) + losses_ = losses_[::-1] + quantiles_ = quantiles_[::-1] + assert_allclose(losses, losses_, atol=1e-11) + assert_allclose(quantiles, quantiles_, atol=1e-11) rng = np.random.default_rng(global_random_seed) @@ -2984,10 +2996,22 @@ def assert_same_results(y, w, indices, reverse=False): assert_same_results(y, w, indices, reverse=True) -def test_absolute_error_accurately_predicts_weighted_median(global_random_seed): +@pytest.mark.parametrize( + "criterion, quantile", + [ + ("absolute_error", 0.5), + ("quantile", 0.3), + ("quantile", 0.5), + ("quantile", 0.9), + ], +) +def test_quantile_criterion_predicts_weighted_quantile( + criterion, quantile, global_random_seed +): """ - Test that the weighted-median computed under-the-hood when - building a tree with criterion="absolute_error" is correct. + Test that the weighted quantile computed under-the-hood when building a tree + with criterion="quantile" is correct. The absolute error criterion is the + special case with quantile=0.5. """ rng = np.random.default_rng(global_random_seed) n = int(1e5) @@ -2995,14 +3019,45 @@ def test_absolute_error_accurately_predicts_weighted_median(global_random_seed): # Large number of zeros and otherwise continuous weights: weights = rng.integers(0, 3, size=n) * rng.uniform(0, 1, size=n) - tree_leaf_weighted_median = ( - DecisionTreeRegressor(criterion="absolute_error", max_depth=1) + tree_leaf_weighted_quantile = ( + DecisionTreeRegressor(criterion=criterion, quantile=quantile, max_depth=1) .fit(np.ones(shape=(data.shape[0], 1)), data, sample_weight=weights) .tree_.value.ravel()[0] ) - weighted_median = _weighted_percentile(data, weights, 50, average=True) + weighted_quantile = _weighted_percentile( + data, weights, quantile * 100, average=True + ) + + assert_allclose(tree_leaf_weighted_quantile, weighted_quantile) + + +def test_quantile_confidence_interval_coverage(global_random_seed): + """ + Test that quantile regression confidence intervals have appropriate coverage. 
+ """ + rng = np.random.default_rng(global_random_seed) + n_samples = 2000 + X = rng.uniform(0.0, 1.0, size=(n_samples, 1)) + y = np.sin(2 * np.pi * X[:, 0]) + + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.4, random_state=0 + ) + common_params = dict( + criterion="quantile", max_depth=6, min_samples_leaf=30, random_state=0 + ) + + lower = DecisionTreeRegressor(**common_params, quantile=0.1) + lower.fit(X_train, y_train) + upper = DecisionTreeRegressor(**common_params, quantile=0.9) + upper.fit(X_train, y_train) + + lower_pred = lower.predict(X_test) + upper_pred = upper.predict(X_test) + coverage = np.mean((y_test >= lower_pred) & (y_test <= upper_pred)) - assert_allclose(tree_leaf_weighted_median, weighted_median) + # Trees are flexible but still approximate; allow some slack around 80%. + assert 0.75 <= coverage <= 0.85 def test_splitting_with_missing_values():