diff --git a/legateboost/callbacks.py b/legateboost/callbacks.py index bacf2693..e17b26dc 100644 --- a/legateboost/callbacks.py +++ b/legateboost/callbacks.py @@ -57,8 +57,8 @@ def after_iteration( class EarlyStopping(TrainingCallback): - """Callback for early stopping during training. The last evaluation dataset - is used for early stopping. + """Callback for early stopping during training. The last evaluation dataset is + used for early stopping. Args: rounds (int): The number of rounds to wait for improvement before stopping. diff --git a/legateboost/legateboost.py b/legateboost/legateboost.py index cd983e46..43b92787 100644 --- a/legateboost/legateboost.py +++ b/legateboost/legateboost.py @@ -194,11 +194,7 @@ def _get_weighted_gradient( learning_rate: float, ) -> Tuple[cn.ndarray, cn.ndarray]: """Computes the weighted gradient and Hessian for the given predictions - and labels. - - Also applies a pre-rounding step to ensure reproducible floating - point summation. - """ + and labels.""" # check input dimensions are consistent assert y.ndim == pred.ndim == 2, (y.shape, pred.shape) g, h = self._objective_instance.gradient( @@ -317,8 +313,8 @@ def update( eval_result: EvalResult = {}, ) -> Self: """Update a gradient boosting model from the training set (X, y). This - method does not add any new models to the ensemble, only updates - existing models to fit the new data. + method does not add any new models to the ensemble, only updates existing + models to fit the new data. Parameters ---------- @@ -477,8 +473,8 @@ def __iter__(self) -> Any: return iter(self.models_) def __mul__(self, scalar: Any) -> Self: - """Gradient boosted models are linear in the predictions before the - non-linear link function is applied. This means that the model can be + """Gradient boosted models are linear in the predictions before the non- + linear link function is applied. This means that the model can be multiplied by a scalar, which subsequently scales all raw output predictions. This is useful for ensembling models. @@ -550,8 +546,8 @@ def global_attributions( n_samples: int = 5, check_efficiency: bool = False, ) -> Tuple[cn.array, cn.array]: - r"""Compute global feature attributions for the model. Global - attributions show the effect of a feature on a model's loss function. + r"""Compute global feature attributions for the model. Global attributions + show the effect of a feature on a model's loss function. We use a Shapley value approach to compute the attributions: :math:`Sh_i(v)=\frac{1}{|N|!} \sum_{\sigma \in \mathfrak{S}_d} \big[ v([\sigma]_{i-1} \cup\{i\}) - v([\sigma]_{i-1}) \big],` @@ -612,11 +608,10 @@ def local_attributions( n_samples: int = 5, check_efficiency: bool = False, ) -> Tuple[cn.array, cn.array]: - r"""Local feature attributions for model predictions. Shows the effect - of a feature on each output prediction. See the definition of Shapley - values in :func:`~legateboost.BaseModel.global_attributions`, where the - :math:`v` function is here the model prediction instead of the loss - function. + r"""Local feature attributions for model predictions. Shows the effect of a + feature on each output prediction. See the definition of Shapley values in + :func:`~legateboost.BaseModel.global_attributions`, where the :math:`v` + function is here the model prediction instead of the loss function. 
Parameters ---------- @@ -750,8 +745,8 @@ def partial_fit( eval_set: List[Tuple[cn.ndarray, ...]] = [], eval_result: EvalResult = {}, ) -> LBBase: - """This method is used for incremental (online) training of the model. - An additional `n_estimators` models will be added to the ensemble. + """This method is used for incremental (online) training of the model. An + additional `n_estimators` models will be added to the ensemble. Parameters ---------- @@ -928,8 +923,8 @@ def partial_fit( eval_result: EvalResult = {}, ) -> LBBase: """This method is used for incremental fitting on a batch of samples. - Requires the classes to be provided up front, as they may not be - inferred from the first batch. + Requires the classes to be provided up front, as they may not be inferred + from the first batch. Parameters ---------- @@ -1033,8 +1028,8 @@ def fit( return self def predict_raw(self, X: cn.ndarray) -> cn.ndarray: - """Predict pre-transformed values for samples in X. E.g. before - applying a sigmoid function. + """Predict pre-transformed values for samples in X. E.g. before applying a + sigmoid function. Parameters ---------- @@ -1064,7 +1059,7 @@ def predict_proba(self, X: cn.ndarray) -> cn.ndarray: Returns ------- - y : + probabilities: The predicted class probabilities for each sample in X. """ X = _lb_check_X(X) @@ -1091,4 +1086,5 @@ def predict(self, X: cn.ndarray) -> cn.ndarray: y : The predicted class labels for each sample in X. """ - return cn.argmax(self.predict_proba(X), axis=1) + check_is_fitted(self) + return self._objective_instance.output_class(self.predict_proba(X)) diff --git a/legateboost/metrics.py b/legateboost/metrics.py index 98c82154..f0e0a7f9 100644 --- a/legateboost/metrics.py +++ b/legateboost/metrics.py @@ -17,6 +17,7 @@ "GammaDevianceMetric", "QuantileMetric", "LogLossMetric", + "MultiLabelMetric", "ExponentialMetric", ] @@ -144,8 +145,8 @@ def name(self) -> str: class GammaLLMetric(BaseMetric): - """The mean negative log likelihood of the labels, given parameters - predicted by the model.""" + """The mean negative log likelihood of the labels, given parameters predicted + by the model.""" @override def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> cn.ndarray: @@ -252,8 +253,8 @@ def name(self) -> str: class LogLossMetric(BaseMetric): - """Class for computing the logarithmic loss (logloss) metric between the - true labels and predicted labels. + """Class for computing the logarithmic loss (logloss) metric between the true + labels and predicted labels. For binary classification: @@ -273,7 +274,7 @@ class LogLossMetric(BaseMetric): def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> cn.ndarray: y = y.squeeze() eps = cn.finfo(pred.dtype).eps - cn.clip(pred, eps, 1 - eps, out=pred) + pred = cn.clip(pred, eps, 1 - eps) w_sum = w.sum() @@ -296,6 +297,26 @@ def name(self) -> str: return "log_loss" +class MultiLabelMetric(BaseMetric): + """Multi-label metric is a binary log-loss metric averaged over multiple + labels. 
+ + See also: + :class:`legateboost.objectives.MultiLabelObjective` + """ # noqa: E501 + + def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> cn.ndarray: + y = y.squeeze() + eps = cn.finfo(pred.dtype).eps + pred = cn.clip(pred, eps, 1 - eps) + w_sum = w.sum() + logloss = -(y * cn.log(pred) + (self.one - y) * cn.log(self.one - pred)) + return (logloss * w[:, cn.newaxis]).sum() / w_sum + + def name(self) -> str: + return "multi_label" + + class ExponentialMetric(BaseMetric): """Class for computing the exponential loss metric. diff --git a/legateboost/models/base_model.py b/legateboost/models/base_model.py index d6e6e416..2a5df986 100644 --- a/legateboost/models/base_model.py +++ b/legateboost/models/base_model.py @@ -11,9 +11,9 @@ class BaseModel(PickleCupynumericMixin, ABC): """Base class for all models in LegateBoost. - Defines the interface for fitting, updating, and predicting a model, - as well as string representation and equality comparison. Implement - these methods to create a custom model. + Defines the interface for fitting, updating, and predicting a model, as well + as string representation and equality comparison. Implement these methods to + create a custom model. """ def set_random_state(self, random_state: np.random.RandomState) -> "BaseModel": @@ -27,8 +27,7 @@ def fit( g: cn.ndarray, h: cn.ndarray, ) -> "BaseModel": - """Fit the model to a second order Taylor expansion of the loss - function. + """Fit the model to a second order Taylor expansion of the loss function. Parameters ---------- diff --git a/legateboost/models/krr.py b/legateboost/models/krr.py index 65249fd0..16d7f445 100644 --- a/legateboost/models/krr.py +++ b/legateboost/models/krr.py @@ -35,9 +35,9 @@ def rbf(x: cn.ndarray, sigma: float) -> cn.ndarray: class KRR(BaseModel): - """Kernel Ridge Regression model using the Nyström approximation. The - accuracy of the approximation is governed by the parameter `n_components` - <= `n`. Effectively, `n_components` rows will be randomly sampled (without + """Kernel Ridge Regression model using the Nyström approximation. The accuracy + of the approximation is governed by the parameter `n_components` <= `n`. + Effectively, `n_components` rows will be randomly sampled (without replacement) from X in each boosting iteration. The kernel is fixed to be the RBF kernel: diff --git a/legateboost/models/linear.py b/legateboost/models/linear.py index d660bc23..aad65c83 100644 --- a/legateboost/models/linear.py +++ b/legateboost/models/linear.py @@ -9,11 +9,11 @@ class Linear(BaseModel): - """Generalised linear model. Boosting linear models is equivalent to - fitting a single linear model where each boosting iteration is a newton - step. Note that the l2 penalty is applied to the weights of each model, as - opposed to the sum of all models. This can lead to different results when - compared to fitting a linear model with sklearn. + """Generalised linear model. Boosting linear models is equivalent to fitting a + single linear model where each boosting iteration is a newton step. Note that + the l2 penalty is applied to the weights of each model, as opposed to the sum + of all models. This can lead to different results when compared to fitting a + linear model with sklearn. It is recommended to normalize the data before fitting. This ensures regularisation is evenly applied to all features and prevents numerical issues. 
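For reference, the MultiLabelMetric added to metrics.py above evaluates an independent binary log-loss for every label and normalises by the total sample weight. A minimal NumPy sketch of the same arithmetic on made-up values (illustrative only, not taken from the library or its tests):

import numpy as np

# Two samples, two labels, unit sample weights (hypothetical toy values).
y = np.array([[1.0, 0.0], [0.0, 1.0]])
pred = np.array([[0.9, 0.2], [0.3, 0.8]])
w = np.ones(2)

eps = np.finfo(pred.dtype).eps
pred = np.clip(pred, eps, 1 - eps)

# Element-wise binary log-loss, weighted per sample and normalised by the
# total weight, mirroring MultiLabelMetric.metric above.
logloss = -(y * np.log(pred) + (1 - y) * np.log(1 - pred))
value = (logloss * w[:, np.newaxis]).sum() / w.sum()
print(value)  # ~0.454 for these toy values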
diff --git a/legateboost/objectives.py b/legateboost/objectives.py index b4ffdaf7..ea389660 100644 --- a/legateboost/objectives.py +++ b/legateboost/objectives.py @@ -14,6 +14,7 @@ GammaLLMetric, LogLossMetric, MSEMetric, + MultiLabelMetric, NormalLLMetric, QuantileMetric, ) @@ -26,6 +27,7 @@ "SquaredErrorObjective", "NormalObjective", "LogLossObjective", + "MultiLabelObjective", "ExponentialObjective", "QuantileObjective", "GammaDevianceObjective", @@ -53,7 +55,7 @@ def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> GradPair: Returns: The functional gradient and hessian of the squared error - objective function. + objective function, both of which must be 2D arrays. """ # noqa: E501 pass @@ -61,10 +63,10 @@ def transform(self, pred: cn.ndarray) -> cn.ndarray: """Transforms the predicted labels. E.g. sigmoid for log loss. Args: - pred : The predicted labels. + pred : Raw predictions. Returns: - The transformed labels. + An n-d array of transformed predictions. For classification problems this is a probability. """ return pred @@ -81,7 +83,9 @@ def metric(self) -> BaseMetric: def initialise_prediction( self, y: cn.ndarray, w: cn.ndarray, boost_from_average: bool ) -> cn.ndarray: - """Initializes the base score of the model. May also validate labels. + """Initializes the base score of the model, either to a + baseline value or to a value minimising the objective. Should also + validate the labels, i.e. check that y is suitable for this objective. Args: y : The target values. @@ -90,11 +94,28 @@ from the average of the target values. Returns: - The initial predictions for a single example. + The initial (untransformed) prediction for a single example. """ pass +class ClassificationObjective(BaseObjective): + """Extension of BaseObjective for classification problems. Users can optionally + override the method for extracting a class output from probabilities.""" + + def output_class(self, pred: cn.ndarray) -> cn.ndarray: + """Defines how to output class labels from the transformed output. This may be + as simple as argmax over probabilities. + + Args: + pred (cn.ndarray): The transformed predictions. + + Returns: + cn.ndarray: The class labels. + """ + return cn.argmax(pred, axis=-1) + + class SquaredErrorObjective(BaseObjective): """The Squared Error objective function for regression problems. @@ -318,8 +339,8 @@ def initialise_prediction( class GammaObjective(FitInterceptRegMixIn, Forecast): - """Regression with the :math:`\\Gamma` distribution function using the - shape scale parameterization.""" + """Regression with the :math:`\\Gamma` distribution function using the shape + scale parameterization.""" @override def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> GradPair: @@ -398,8 +419,7 @@ def var(self, param: cn.ndarray) -> cn.ndarray: class QuantileObjective(BaseObjective): - """Minimises the quantile loss, otherwise known as check loss or pinball - loss. + """Minimises the quantile loss, otherwise known as check loss or pinball loss. :math:`L(y_i, p_i) = \\frac{1}{k}\\sum_{j=1}^{k} (q_j - \\mathbb{1})(y_i - p_{i, j})` @@ -454,9 +474,9 @@ def initialise_prediction( return cn.zeros_like(self.quantiles) -class LogLossObjective(FitInterceptRegMixIn): - """The Log Loss objective function for binary and multi-class - classification problems. +class LogLossObjective(ClassificationObjective): + """The Log Loss objective function for binary and multi-class classification + problems. This objective function computes the log loss between the predicted and true labels. 
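To make the default output_class hook concrete: ClassificationObjective simply takes an argmax over the last axis of the transformed probabilities, and predict() in legateboost.py now delegates class extraction to this hook. A small NumPy illustration with made-up probabilities (cupynumeric exposes the same argmax API):

import numpy as np

# Made-up class probabilities for 3 samples over 4 classes.
proba = np.array([
    [0.10, 0.60, 0.20, 0.10],
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.05, 0.40, 0.50],
])

# Default ClassificationObjective.output_class: argmax over the last axis.
print(np.argmax(proba, axis=-1))  # [1 0 3]

Routing class extraction through the objective is what lets MultiLabelObjective, added further down, return a 0/1 label matrix via a 0.5 threshold instead of a single class index per sample.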
@@ -484,6 +504,7 @@ def transform(self, pred: cn.ndarray) -> cn.ndarray: assert len(pred.shape) == 2 if pred.shape[1] == 1: return self.one / (self.one + cn.exp(-pred)) + # softmax function s = cn.max(pred, axis=1) e_x = cn.exp(pred - s[:, cn.newaxis]) @@ -500,6 +521,8 @@ def initialise_prediction( raise ValueError("Expected labels to be non-zero whole numbers") num_class = int(cn.max(y) + 1) n_targets = num_class if num_class > 2 else 1 + if not boost_from_average: + return cn.zeros(n_targets, dtype=cn.float64) if n_targets == 1: prob = y.sum() / y.size return -cn.log(1 / prob - 1).reshape(1) @@ -508,9 +531,42 @@ def initialise_prediction( return cn.log(prob) -class ExponentialObjective(FitInterceptRegMixIn): - """Exponential loss objective function for binary classification. - Equivalent to the AdaBoost multiclass exponential loss in [1]. +class MultiLabelObjective(ClassificationObjective): + """Used for multi-label classification problems, i.e. the model can predict + more than one output class. + + We apply an independent sigmoid function/logloss to each class. + + See also: + :class:`legateboost.metrics.MultiLabelMetric` + """ + + def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> GradPair: + return pred - y, pred * (self.one - pred) + + def transform(self, pred: cn.ndarray) -> cn.ndarray: + return self.one / (self.one + cn.exp(-pred)) + + def output_class(self, pred: cn.ndarray) -> cn.ndarray: + return cn.array(pred > 0.5, dtype=cn.int32).squeeze() + + def metric(self) -> MultiLabelMetric: + return MultiLabelMetric() + + def initialise_prediction( + self, y: cn.ndarray, w: cn.ndarray, boost_from_average: bool + ) -> cn.ndarray: + if not cn.all((y == 1.0) | (y == 0.0)): + raise ValueError("Expected labels to be in [0, 1]") + if not boost_from_average: + return cn.zeros((1, y.shape[1]), dtype=cn.float64) + prob = y.sum(axis=0) / y.shape[0] + return -cn.log(1 / prob - 1) + + +class ExponentialObjective(ClassificationObjective, FitInterceptRegMixIn): + """Exponential loss objective function for binary classification. Equivalent + to the AdaBoost multiclass exponential loss in [1]. 
Defined as: @@ -577,6 +633,7 @@ def initialise_prediction( "normal": NormalObjective, "log_loss": LogLossObjective, "exp": ExponentialObjective, + "multi_label": MultiLabelObjective, "quantile": QuantileObjective, "gamma_deviance": GammaDevianceObjective, "gamma": GammaObjective, diff --git a/legateboost/test/test_estimator.py b/legateboost/test/test_estimator.py index 5d02b421..fd293bc5 100644 --- a/legateboost/test/test_estimator.py +++ b/legateboost/test/test_estimator.py @@ -10,6 +10,7 @@ def test_init(): + # regressor np.random.seed(2) X = np.random.random((100, 10)) y = np.random.random((X.shape[0], 2)) @@ -21,6 +22,15 @@ model = model.fit(X, y, sample_weight=w) assert cn.allclose(model.model_init_, y[0:50].mean(axis=0)) + # classifier + y = np.random.randint(0, 2, X.shape[0]) + model = lb.LBClassifier(n_estimators=0, init="average").fit(X, y) + obj = lb.LogLossObjective() + p = y.mean() + assert cn.allclose(obj.transform(model.model_init_.reshape(-1, 1)), p) + model = lb.LBClassifier(n_estimators=0, init=None).fit(X, y) + assert cn.allclose(model.model_init_, 0.0) + @pytest.mark.parametrize("init", [None, "average"]) def test_update(init): @@ -133,6 +143,7 @@ def test_classifier(num_class, objective, base_models): proba = model.predict_proba(X) assert cn.all(proba >= 0) and cn.all(proba <= 1) assert cn.all(cn.argmax(proba, axis=1) == model.predict(X)) + assert cn.allclose(proba.sum(axis=1), cn.ones(X.shape[0])) loss = metric.metric(y, proba, cn.ones(y.shape[0])) train_loss = next(iter(eval_result["train"].values())) @@ -225,3 +236,16 @@ def test_iterator_methods(): assert list(model) == list(model.models_) for i, est in enumerate(model): assert est == model[i] + + +def test_multi_label(): + np.random.seed(2) + X = np.random.random((100, 10)) + y = np.random.randint(0, 2, (X.shape[0], 5)) + eval_result = {} + model = lb.LBClassifier( + n_estimators=5, base_models=(lb.models.Linear(),), objective="multi_label" + ).fit(X, y, eval_result=eval_result) + assert model.predict(X).shape == y.shape + assert model.predict_proba(X).shape == (X.shape[0], 5) + assert non_increasing(eval_result["train"]["multi_label"]) diff --git a/legateboost/test/test_objective.py b/legateboost/test/test_objective.py index 496dc475..aaf8823c 100644 --- a/legateboost/test/test_objective.py +++ b/legateboost/test/test_objective.py @@ -126,3 +126,16 @@ def test_exp(): ), False, ) + + +def test_multi_label(): + obj = lb.MultiLabelObjective() + g, h = obj.gradient( + cn.array([[1, 0], [0, 1]]), + cn.array([[0.5, 0.5], [0.5, 0.5]]), + ) + assert cn.allclose(g, cn.array([[-0.5, 0.5], [0.5, -0.5]])) + assert cn.allclose(h, cn.array([[0.25, 0.25], [0.25, 0.25]])) + + with pytest.raises(ValueError, match=r"Expected labels to be in \[0, 1\]"): + obj.initialise_prediction(cn.array([[1], [2]]), cn.array([[1.0], [1.0]]), False) diff --git a/legateboost/utils.py b/legateboost/utils.py index d8730fe0..935ad43c 100644 --- a/legateboost/utils.py +++ b/legateboost/utils.py @@ -156,14 +156,14 @@ def get_store(input: Any) -> LogicalStore: def solve_singular(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray: - """Solve a singular linear system Ax = b for x. The same as - np.linalg.solve, but if A is singular, then we use Algorithm 3.3 from: + """Solve a singular linear system Ax = b for x. The same as np.linalg.solve, + but if A is singular, then we use Algorithm 3.3 from: - Nocedal, Jorge, and Stephen J. Wright, eds. Numerical optimization. - New York, NY: Springer New York, 1999. 
+ Nocedal, Jorge, and Stephen J. Wright, eds. Numerical optimization. New York, + NY: Springer New York, 1999. - This progressively adds to the diagonal of the matrix until it is - non-singular. + This progressively adds to the diagonal of the matrix until it is non- + singular. """ # ensure we are doing all calculations in float 64 for stability a = a.astype(np.float64) @@ -202,8 +202,7 @@ def solve_singular(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray: def sample_average( y: cn.ndarray, sample_weight: Optional[cn.ndarray] = None ) -> cn.ndarray: - """Compute weighted average on the first axis (usually the sample - dimension). + """Compute weighted average on the first axis (usually the sample dimension). Returns 0 if sum weight is zero or if the input is empty. """ diff --git a/pyproject.toml b/pyproject.toml index 092d8f3d..6b06ef11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,3 +128,7 @@ skip = [ skip_glob = [ "**/__init__.py" ] + +[tool.docformatter] +wrap-summaries = 82 +wrap-descriptions = 81
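The solve_singular helper in utils.py above describes its fallback strategy: keep adding to the diagonal until the matrix becomes non-singular (Nocedal & Wright, Algorithm 3.3). A simplified, self-contained NumPy sketch of that idea — illustrative only, with hypothetical names, not the library's exact implementation:

import numpy as np

def solve_with_diagonal_shift(a, b, max_tries=20):
    # Try a Cholesky factorisation; on failure, grow a diagonal shift tau and retry,
    # loosely following the "Cholesky with added multiple of the identity" scheme.
    a = a.astype(np.float64)
    b = b.astype(np.float64)
    tau, beta = 0.0, 1e-3
    for _ in range(max_tries):
        shifted = a + tau * np.eye(a.shape[0])
        try:
            np.linalg.cholesky(shifted)  # raises LinAlgError if not positive definite
            return np.linalg.solve(shifted, b)
        except np.linalg.LinAlgError:
            tau = max(2 * tau, beta)
    raise np.linalg.LinAlgError("could not regularise the system")

# A singular 2x2 system: the shift makes it solvable, giving x close to [0.5, 0.5].
print(solve_with_diagonal_shift(np.array([[1.0, 1.0], [1.0, 1.0]]), np.array([1.0, 1.0])))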