diff --git a/conda/environments/all_cuda-122.yaml b/conda/environments/all_cuda-122.yaml index cae2f766..0e440e99 100644 --- a/conda/environments/all_cuda-122.yaml +++ b/conda/environments/all_cuda-122.yaml @@ -40,6 +40,7 @@ dependencies: - rich - scikit-build-core>=0.10.0 - scikit-learn>=1.6 +- scikit-learn>=1.6,!=1.7.1 - seaborn>=0.13 - sphinx>=8.0,<8.2.0 - typing-extensions>=4.0 diff --git a/dependencies.yaml b/dependencies.yaml index 14a5bd4b..05c5fe53 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -128,6 +128,7 @@ dependencies: common: - output_types: [conda, pyproject, requirements] packages: + - scikit-learn>=1.6,!=1.7.1 # 1.7.1 has doc build issues - myst-parser>=4.0 - pydata-sphinx-theme>=0.16 # the ceiling on sphinx can be removed when https://github.com/spatialaudio/nbsphinx/issues/825 is resolved diff --git a/docs/source/index.rst b/docs/source/index.rst index cc0dd88a..cb3f9655 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,6 +11,7 @@ legate-boost documentation Introduction Contributing Python API Reference + Tutorial Indices and tables ================== diff --git a/docs/source/legate-boost.rst b/docs/source/legate-boost.rst new file mode 100644 index 00000000..ce76bec7 --- /dev/null +++ b/docs/source/legate-boost.rst @@ -0,0 +1,442 @@ +.. _legate-boost: + + +============= +legate-boost +============= + +This article assumes familiarity with the basic usage of gradient boosting +libraries such as XGBoost or LightGBM, as well as ``cupynumeric`` for GPU-accelerated +array computations. In this tutorial, these libraries are used for efficient model +training, large-scale data handling, and accelerating computations across CPUs and GPUs. + +What is legate-boost? +===================== + +In scenarios where high-performance training is needed across large +datasets or distributed hardware, ``legate-boost`` offers a scalable +alternative. ``legate-boost`` is an advanced gradient boosting library built +on the ``Legate`` and Legion parallel programming frameworks. Unlike +traditional boosting libraries such as XGBoost or LightGBM, ``legate-boost`` +provides a unified infrastructure that seamlessly scales across CPU's and +GPU's, supporting both single-node and distributed training while +integrating naturally with ``cupynumeric`` workflows for efficient +end-to-end data processing. It enables users to define not only +conventional boosted decision trees but also hybrid ensembles combining +trees, kernel ridge regression, linear models, or neural networks, all +written in Python with minimal code changes. + +These models are automatically parallelized and executed efficiently +without manual data movement or partitioning. ``legate-boost`` emphasizes +architectural simplicity, extensibility, and performance, delivering +state-of-the-art results on tabular data while leveraging the full +computational power of modern heterogeneous hardware. + +Please refer to `Distributed Computing with cupynumeric`_ +and `legate boost`_ for more +information and detailed instructions on installation and setup. + +.. _Distributed Computing with cupynumeric: https://docs.nvidia.com/cupynumeric/latest/user/tutorial.html + +.. _legate boost: https://github.com/rapidsai/legate-boost?tab=readme-ov-file#installation. 
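+As a quick illustration of this flexibility (a minimal sketch only: the
+parameter values below are illustrative and assume the default constructor
+arguments of ``lb.models.Linear``; the estimators themselves are covered in
+detail in the rest of this tutorial), a hybrid ensemble that adds one
+decision tree and one linear model per boosting round can be written as:
+
+.. code-block:: python
+
+   from sklearn.datasets import make_regression
+
+   import legateboost as lb
+
+   # synthetic data, purely for illustration
+   X, y = make_regression(n_samples=1000, n_features=10, random_state=0)
+
+   # each boosting round fits one tree and one linear model
+   model = lb.LBRegressor(
+       n_estimators=20,
+       base_models=(lb.models.Tree(max_depth=4), lb.models.Linear()),
+       learning_rate=0.1,
+   ).fit(X, y)
+
+   predictions = model.predict(X)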
+
+Usage
+=====
+
+``legate-boost`` offers two main estimator types:
+
+- ``LBRegressor`` for regression tasks
+- ``LBClassifier`` for classification tasks
+
+These estimators follow a similar interface to those in XGBoost,
+making them easy to integrate into existing machine learning pipelines.
+
+Regression with LBRegressor
+---------------------------
+
+The ``LBRegressor`` estimator is used to predict continuous values such as
+house prices, temperatures, or sales forecasts. The following code
+demonstrates how to create an instance of the ``LBRegressor`` model, use the
+``fit()`` method to train it on a dataset, and then apply the ``predict()``
+method to generate predictions on new data. Here’s how to set it up:
+
+
+.. literalinclude:: ../../examples/tutorial_examples/LBRegressor.py
+   :language: python
+
+In this example:
+~~~~~~~~~~~~~~~~
+
+- ``LBRegressor`` is initialized with 100 boosting estimators.
+- The ``fit()`` method trains the model using the input features (``X_train``)
+  and target values (``y_train``).
+- After training, the ``predict()`` method is used to make predictions on
+  the test set (``X_test``).
+
+This represents a typical workflow for applying a regression model using
+``legate-boost``. The ``LBRegressor`` estimator offers several configurable
+options, such as ``base_models`` and ``learning_rate``, to help optimize model
+performance. For a comprehensive list of features and parameters, refer
+to the `official documentation`_.
+
+.. _official documentation: https://rapidsai.github.io/legate-boost/api/estimators.html
+
+Classification with LBClassifier
+--------------------------------
+
+The ``LBClassifier`` estimator is designed for predicting categorical outcomes and
+supports both binary and multi-class classification tasks. It is ideal
+for a wide range of applications, including spam detection, image
+classification, and sentiment analysis.
+
+The example below demonstrates how to implement a classification model
+using the ``LBClassifier`` estimator from ``legate-boost``:
+
+.. literalinclude:: ../../examples/tutorial_examples/LBClassifier.py
+   :language: python
+
+In this example:
+~~~~~~~~~~~~~~~~
+
+- ``LBClassifier(n_estimators=50)`` sets up a classifier with 50 boosting
+  rounds.
+
+- ``fit(X_train, y_train)`` learns the patterns from your training dataset.
+
+- ``predict(X_test)`` outputs predicted class labels for the test dataset.
+
+Just like the regressor, the ``LBClassifier`` follows a clean and intuitive
+workflow. It provides additional options and advanced configurations to
+optimize model performance. For more detailed information, refer to the
+legate-boost `estimators`_ documentation.
+
+.. _estimators: https://rapidsai.github.io/legate-boost/api/estimators.html#legateboost.LBClassifier
+
+Example 1
+=========
+
+Here is an example of using ``legate-boost`` to build a regression model on
+the California housing dataset. It showcases key features such as scalable
+training across GPUs and nodes, customizable base models, and adjustable
+learning rates.
+
+About the dataset
+-----------------
+
+The California housing dataset is a classic benchmark containing
+information collected from California districts in the 1990 census. Each
+record describes a block group (a neighborhood-level area), including
+predictors such as:
+
+- Median income of residents
+- Average house age
+- Average number of rooms and bedrooms
+- Population and household count
+- Latitude and longitude
+
+The target variable is the median house value in that block group.
+This dataset is often used to illustrate regression techniques and
+assess predictive performance on real-world tabular data.
+
+About this implementation
+-------------------------
+
+The following code creates a ``legate-boost`` regression model using
+``LBRegressor``, which trains a gradient boosting model optimized for
+multi-GPU and multi-node environments. The model is configured to use
+100 boosting rounds (``n_estimators=100``), with each round adding a
+decision tree (``lb.models.Tree``) limited to a maximum depth of 8. The loss
+function is set to ``squared_error``, suitable for regression tasks as it
+minimizes mean squared error. A ``learning_rate`` of 0.1 controls how much
+each tree contributes to the final prediction, balancing speed and
+stability. The ``verbose=True`` flag enables logging during training,
+allowing you to monitor progress and internal operations.
+
+Code module
+-----------
+
+.. literalinclude:: ../../examples/tutorial_examples/housing.py
+   :language: python
+
+This simple example demonstrates how to train a regression model on the
+California housing dataset using ``legate-boost``. Although the code looks
+similar to standard XGBoost, ``legate-boost`` automatically enables
+multi-GPU and multi-node computation. ``legate-boost`` achieves multi-GPU
+and multi-node scaling through its integration with ``cupynumeric`` and the
+Legion runtime. Unlike traditional GPU libraries that allocate data to a
+single device, ``cupynumeric`` creates ``logical arrays``, abstract
+representations of the data that are not bound to one GPU. The ``Legate``
+runtime automatically partitions these ``logical arrays`` into physical
+chunks and maps them across all available GPUs and nodes.
+
+During training, operations such as histogram building, gradient
+computation, and tree construction are expressed as parallel tasks.
+``Legate`` schedules these tasks close to where the data resides, minimizing
+communication overhead. When synchronization is needed (e.g., to combine
+histograms from multiple GPUs), it is handled by ``legate-mpi-wrapper`` and
+``realm-gex-wrapper``, so we never have to write MPI code or manage explicit
+GPU memory transfers.
+
+Running on CPU and GPU
+----------------------
+
+CPU execution
+~~~~~~~~~~~~~
+
+To run on the CPU, use the following command.
+
+.. code-block:: sh
+
+   legate --cpus 1 --gpus 0 ./housing.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   The training time for housing exp is: 1742.303000 ms
+
+
+GPU execution
+~~~~~~~~~~~~~
+
+To run on the GPU, use the following command.
+
+.. code-block:: sh
+
+   legate --gpus 1 ./housing.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   The training time for housing exp is: 831.949000 ms
+
+
+Example 2
+=========
+
+This example demonstrates how ``legate-boost`` can be applied to the ``Give
+Me Some Credit`` dataset (OpenML data_id: 46929) to build a
+classification model using ensemble learning by combining different
+model types. It also highlights the integration of ``Legate Dataframe`` with
+``legate-boost`` to enable distributed training across multi-GPU and
+multi-node environments, showcasing scalable machine learning on the
+credit score dataset.
+
+About the dataset
+-----------------
+
+The ``Give Me Some Credit`` dataset is a financial risk prediction dataset
+originally introduced in a Kaggle competition.
+It includes anonymized credit and demographic data for individuals, with the
+goal of predicting whether a person is likely to experience serious
+financial distress within the next two years.
+
+Each record represents an individual and includes features such as:
+
+- Revolving utilization of unsecured credit lines
+- Age
+- Number of late payments (30–59, 60–89, and 90+ days past due)
+- Debt ratio
+- Monthly income
+- Number of open credit lines and loans
+- Number of dependents
+
+The target variable is binary (0 = no distress, 1 = distress),
+indicating the likelihood of future financial trouble.
+
+About this implementation
+-------------------------
+
+This implementation focuses on demonstrating the flexible model
+ensembling capabilities of ``legate-boost``, specifically:
+
+- Tree-based gradient boosting models, ideal for structured/tabular
+  data.
+- Neural network-based classifiers, allowing hybrid or deep learning
+  approaches.
+
+By leveraging ``legate-boost``, we can ensemble these two model types and
+efficiently train and evaluate them on GPUs or CPUs, showcasing
+scalable performance for large tabular datasets in financial risk
+prediction.
+
+The pipeline begins by importing the required libraries and loading the
+dataset with ``fetch_openml``. The data is first loaded into a pandas
+DataFrame and then wrapped into a ``LogicalTable``, the distributed data
+representation used by ``Legate Dataframe``. A ``LogicalTable`` internally
+breaks data into ``logical columns``, enabling Legate’s runtime to
+partition, distribute, and schedule computations across multiple GPUs and
+nodes.
+
+.. literalinclude:: ../../examples/tutorial_examples/creditscore.py
+   :language: python
+   :start-after: [import libraries]
+   :end-before: [covert to LogicalTable end]
+
+Let’s see how data preprocessing is performed directly on the
+``LogicalTable``. Missing values in key columns (``MonthlyIncome`` and
+``NumberOfDependents``) are filled using median imputation through the
+``replace_nulls`` operation. These operations are executed in parallel
+across distributed partitions of the ``LogicalTable``, avoiding centralized
+bottlenecks. Because a ``LogicalTable`` is immutable, a new ``LogicalTable``
+with updated ``LogicalColumn`` objects is created after preprocessing. The
+cleaned data is then converted into a ``cupynumeric`` array, Legate’s
+GPU-accelerated array type that leverages logical partitioning for
+distributed computation. This enables the subsequent machine learning
+tasks to execute efficiently across multiple GPUs or nodes.
+
+.. literalinclude:: ../../examples/tutorial_examples/creditscore.py
+   :language: python
+   :start-after: [Replace nulls]
+   :end-before: [convert to cupynumeric array end]
+
+With ``data_arr`` backed by ``cupynumeric``, we first split the dataset
+into training and testing subsets, which are then passed to ``legate-boost``
+for efficient training across available hardware resources. The model is
+built using the ensemble framework of ``legate-boost`` (``LBClassifier``),
+which allows combining multiple types of base learners into a single unified
+model.
+
+In this example, the ensemble consists of a decision tree
+(``lb.models.Tree``) with ``max_depth=5``, enabling the capture of complex
+non-linear decision boundaries by splitting the feature space
+hierarchically up to 5 levels deep, and a neural network (``lb.models.NN``)
+with two hidden layers of 10 neurons each (``hidden_layer_sizes=(10, 10)``),
+trained for ``max_iter=10`` epochs with ``verbose=True`` to monitor progress.
+By combining a tree-based model with a neural network, ``legate-boost``
+leverages the interpretability and rule-based decision-making of trees
+together with the ability of neural networks to model intricate,
+high-dimensional relationships. This ensemble design can yield a more
+accurate and robust classifier than either model could achieve
+individually.
+
+.. literalinclude:: ../../examples/tutorial_examples/creditscore.py
+   :language: python
+   :start-after: [preparing data for training and testing]
+   :end-before: [training end]
+
+The trained ensemble model is used to generate predictions on the test
+set, and its accuracy is evaluated using ``accuracy_score()``.
+
+.. literalinclude:: ../../examples/tutorial_examples/creditscore.py
+   :language: python
+   :start-after: [Prediction]
+   :end-before: [Inference]
+
+This workflow illustrates how ``Legate Dataframe`` provides a scalable
+preprocessing layer, ``cupynumeric`` arrays enable distributed GPU
+computation, and ``legate-boost`` delivers a flexible ensemble learning
+framework capable of leveraging multi-node, multi-GPU infrastructure
+efficiently.
+
+Running on CPU and GPU
+----------------------
+
+CPU execution
+~~~~~~~~~~~~~
+
+To run on the CPU, use the following command.
+
+.. code-block:: sh
+
+   legate --cpus 8 --gpus 0 ./creditscore.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   Accuracy: 0.9343
+   The training time for creditscore exp is: 10912.335000 ms
+
+GPU execution
+~~~~~~~~~~~~~
+
+To run on the GPU, use the following command.
+
+.. code-block:: sh
+
+   legate --gpus 2 ./creditscore.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   Accuracy: 0.9357
+   The training time for creditscore exp is: 2688.233000 ms
+
+Inference performance
+=====================
+
+Let’s explore how ``cupynumeric`` can be leveraged to measure inference
+performance statistics seamlessly across both CPU and GPU, all without
+modifying the code. In this example, we evaluate a pre-trained machine
+learning model by calculating key metrics such as the mean, median, minimum,
+maximum, variance, and standard deviation of inference times. The model
+trained above is reused, and its predictions are executed multiple times on
+the test dataset. By utilizing ``cupynumeric`` arrays, the timing results
+are efficiently processed while ensuring compatibility with both CPU and
+GPU environments. This approach provides a simple yet powerful way to
+compare inference performance across hardware, offering clear insights
+into the speedup and variability achieved with GPU acceleration.
+
+
+.. literalinclude:: ../../examples/tutorial_examples/creditscore.py
+   :language: python
+   :start-after: [Inference]
+
+
+Running on CPU and GPU
+----------------------
+
+CPU execution
+~~~~~~~~~~~~~
+
+To run on the CPU, use the following command.
+
+.. code-block:: sh
+
+   legate --cpus 8 --gpus 0 ./inference.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   Mean: 167.97 ms
+   Median: 170.25 ms
+   Min: 161.46 ms
+   Max: 176.31 ms
+   Variance: 23.52 ms
+   standard deviation: 4.85 ms
+
+
+GPU execution
+~~~~~~~~~~~~~
+
+To run on the GPU, use the following command.
+
+.. code-block:: sh
+
+   legate --gpus 2 ./inference.py
+
+This produces the following output:
+
+.. code-block:: text
+
+   Mean: 132.44 ms
+   Median: 131.58 ms
+   Min: 130.56 ms
+   Max: 136.72 ms
+   Variance: 3.42 ms
+   standard deviation: 1.85 ms
+
+These results clearly show the performance benefits of running inference
+on a GPU compared to a CPU using ``cupynumeric`` arrays. On the CPU, the
+model achieved a mean inference time of approximately 167.97 ms,
+with relatively low variability (standard deviation ~\ 4.85 ms). In
+contrast, the GPU reduced the mean inference time to around 132.44 ms,
+a speedup of roughly 1.3x, with even lower variability (standard
+deviation ~\ 1.85 ms). This highlights how ``cupynumeric`` enables the same
+code to seamlessly scale across CPU and GPU, allowing both accurate
+performance benchmarking and efficient model deployment across
+heterogeneous hardware.
diff --git a/examples/tutorial_examples/LBClassifier.py b/examples/tutorial_examples/LBClassifier.py
new file mode 100644
index 00000000..ab13f9ba
--- /dev/null
+++ b/examples/tutorial_examples/LBClassifier.py
@@ -0,0 +1,22 @@
+from sklearn.datasets import make_classification
+from sklearn.model_selection import train_test_split
+
+import legateboost as lb
+
+# creating synthetic dataset
+X, y = make_classification(n_samples=100, n_features=4, n_classes=2, random_state=42)
+
+# splitting the data
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=0.2, random_state=42
+)
+
+# classification model with 50 estimators
+classification_model = lb.LBClassifier(n_estimators=50)
+
+# train the model
+classification_model.fit(X_train, y_train)
+
+# predictions
+y_pred = classification_model.predict(X_test)
+print(y_pred)
diff --git a/examples/tutorial_examples/LBRegressor.py b/examples/tutorial_examples/LBRegressor.py
new file mode 100644
index 00000000..6a35f0a6
--- /dev/null
+++ b/examples/tutorial_examples/LBRegressor.py
@@ -0,0 +1,23 @@
+from sklearn.datasets import make_regression
+from sklearn.model_selection import train_test_split
+
+import legateboost as lb
+
+# creating synthetic dataset
+X, y = make_regression(n_samples=100, n_features=4, noise=8, random_state=42)
+
+# splitting the data
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=0.2, random_state=42
+)
+
+# regression model with 100 estimators
+regression_model = lb.LBRegressor(n_estimators=100)
+
+# fit the model
+regression_model.fit(X_train, y_train)
+
+# predict
+y_pred = regression_model.predict(X_test)
+
+print(y_pred)
diff --git a/examples/tutorial_examples/creditscore.py b/examples/tutorial_examples/creditscore.py
new file mode 100644
index 00000000..b6574b29
--- /dev/null
+++ b/examples/tutorial_examples/creditscore.py
@@ -0,0 +1,122 @@
+# [import libraries]
+import os
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from legate_dataframe.lib.core.column import LogicalColumn
+from legate_dataframe.lib.core.table import LogicalTable
+from legate_dataframe.lib.replace import replace_nulls
+from sklearn.datasets import fetch_openml
+from sklearn.metrics import accuracy_score
+
+import cupynumeric as cpn
+import legate.core as lg
+import legateboost as lb
+from legate.timing import time
+
+rt = lg.get_legate_runtime()
+
+# [import data]
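+# NOTE: the following lines download the "Give Me Some Credit" dataset
+# (OpenML data_id 46929) as a pandas DataFrame, map the Yes/No target to
+# 1/0, and, under CI, subsample the data to keep the run short.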
+data = fetch_openml(data_id=46929, as_frame=True) +df = pd.DataFrame(data.data, columns=data.feature_names) +df["Target"] = data.target.map({"No": 0, "Yes": 1}).astype(np.int8) + +if os.environ.get("CI"): + df = df.sample(n=100, random_state=42).reset_index(drop=True) + +# [convert to LogicalTable] +df_arrow = pa.Table.from_pandas(df) +ldf = LogicalTable.from_arrow(df_arrow) +# [covert to LogicalTable end] + +# [Replace nulls] +median_salary = df["MonthlyIncome"].median() +median_dependents = df["NumberOfDependents"].median() + +mmi = LogicalColumn(replace_nulls(LogicalColumn(ldf["MonthlyIncome"]), median_salary)) +mnd = LogicalColumn( + replace_nulls(LogicalColumn(ldf["NumberOfDependents"]), median_dependents) +) + +# [Create new LogicalTable with updated columns] + +features = ldf.get_column_names() +nldf = LogicalTable( + [ldf[0], ldf[1], ldf[2], ldf[3], mmi, ldf[5], ldf[6], ldf[7], ldf[8], mnd, ldf[10]], + features, +) +# [Convert to cupynumeric array] + +data_arr = nldf.to_array() + +# [convert to cupynumeric array end] + +# [preparing data for training and testing] +x = data_arr[:, :-1] # all columns except last +y = data_arr[:, -1] + +# [Splitting the data into training and testing] +num_samples = x.shape[0] +split_ratio = 0.8 +split_index = int(num_samples * split_ratio) + +x_train = x[:split_index] +y_train = y[:split_index] +x_test = x[split_index:] +y_test = y[split_index:] + +# [training] +rt.issue_execution_fence() +start = time() +nn_iter = 2 if os.environ.get("CI") else 10 +hidden_layers = (2, 2) if os.environ.get("CI") else (10, 10) + +model = lb.LBClassifier( + base_models=( + lb.models.Tree(max_depth=5), + lb.models.NN(max_iter=nn_iter, hidden_layer_sizes=hidden_layers, verbose=True), + ) +).fit(x_train, y_train) + +rt.issue_execution_fence() +end = time() +# [training end] + +# [Prediction] +predictions = model.predict(x_test) +print(type(predictions)) + +# [Evaluation] +acc = accuracy_score(y_test, predictions) +print("Accuracy:", acc) +print(f"\nThe training time for creditscore exp is: {(end - start)/1000:.6f} ms") + +# [Inference] +rt = lg.get_legate_runtime() +timings = [] + +for _ in range(10): + rt.issue_execution_fence() + start = time() + model.predict(x_test) + rt.issue_execution_fence() + end = time() + timings.append(end - start) + +timings = timings[1:] +timings_gpu = cpn.array(timings) + +mean_time = cpn.mean(timings_gpu) +median_time = cpn.median(timings_gpu) +min_time = cpn.min(timings_gpu) +max_time = cpn.max(timings_gpu) +var_time = cpn.var(timings_gpu) +std = cpn.sqrt(var_time) + +print(f"Mean: {float(mean_time)/1000:.2f} ms") +print(f"Median: {float(median_time)/1000:.2f} ms") +print(f"Min: {float(min_time)/1000:.2f} ms") +print(f"Max: {float(max_time)/1000:.2f} ms") +print(f"Variance: {float(var_time)/1000000:.2f} ms") +print(f"standard deviation: {float(std)/1000:.2f} ms") diff --git a/examples/tutorial_examples/housing.py b/examples/tutorial_examples/housing.py new file mode 100644 index 00000000..4054ec0f --- /dev/null +++ b/examples/tutorial_examples/housing.py @@ -0,0 +1,46 @@ +# [Import libraries] +import os + +from sklearn.datasets import fetch_california_housing, make_regression +from sklearn.metrics import mean_squared_error +from sklearn.model_selection import train_test_split + +import legateboost as lb +from legate.timing import time + +# [Import data] +if os.environ.get("CI"): + X, y = make_regression(n_samples=100, n_features=5, n_targets=1, random_state=42) + total_estimators = 10 +else: + data = fetch_california_housing() + X, y = 
data.data, data.target + total_estimators = 100 + +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 +) + +# [Create and fit Legate Boost regressor] +model = lb.LBRegressor( + n_estimators=total_estimators, + base_models=(lb.models.Tree(max_depth=8),), + objective="squared_error", + learning_rate=0.1, + verbose=True, +) + +start = time() +model.fit(X_train, y_train) +end = time() + +# [Prediction] +istart = time() +y_pred = model.predict(X_test) +iend = time() + +# [Evaluation] +mse = mean_squared_error(y_test, y_pred) +print(f"Test MSE: {mse:.4f}") +print(f"\nThe training time for housing exp is: {(end - start)/1000:.6f} ms") +print(f"\nThe inference time for housing exp is {(iend - istart)/1000:.6f} ms") diff --git a/legateboost/callbacks.py b/legateboost/callbacks.py index e17b26dc..bacf2693 100644 --- a/legateboost/callbacks.py +++ b/legateboost/callbacks.py @@ -57,8 +57,8 @@ def after_iteration( class EarlyStopping(TrainingCallback): - """Callback for early stopping during training. The last evaluation dataset is - used for early stopping. + """Callback for early stopping during training. The last evaluation dataset + is used for early stopping. Args: rounds (int): The number of rounds to wait for improvement before stopping. diff --git a/legateboost/encoder.py b/legateboost/encoder.py index 16737fd7..69008ba4 100644 --- a/legateboost/encoder.py +++ b/legateboost/encoder.py @@ -16,18 +16,18 @@ class TargetEncoder(TransformerMixin, BaseEstimator, PickleCupynumericMixin): - """TargetEncoder is a transformer that encodes categorical features using the - mean of the target variable. When `fit_transform` is called, a cross- + """TargetEncoder is a transformer that encodes categorical features using + the mean of the target variable. When `fit_transform` is called, a cross- validation procedure is used to generate encodings for each training fold, which are then applied to the test fold. `fit().transform()` differs from `fit_transform()` in that the former fits the encoder on all the data and - generates encodings for each feature. This encoder is modelled on the sklearn - TargetEncoder with only minor differences in how the CV folds are generated. - As it is difficult to rearrange and gather data from each fold in distributed - environment, training rows are kept in place and then assigned a cv fold by - generating a random integer in the range [0, n_folds). As per sklearn, when - smooth="auto", an empirical Bayes estimate per [#]_ is used to avoid - overfitting. + generates encodings for each feature. This encoder is modelled on the + sklearn TargetEncoder with only minor differences in how the CV folds are + generated. As it is difficult to rearrange and gather data from each fold + in distributed environment, training rows are kept in place and then + assigned a cv fold by generating a random integer in the range [0, + n_folds). As per sklearn, when smooth="auto", an empirical Bayes estimate + per [#]_ is used to avoid overfitting. .. [#] Micci-Barreca, Daniele. "A preprocessing scheme for high-cardinality categorical attributes in classification and prediction problems." ACM SIGKDD explorations newsletter 3.1 (2001): 27-32. @@ -264,8 +264,8 @@ def _get_category_means( """Compute some label summary statistics for each category in the input data. - Returns a 3D array of shape (n_categories, n_outputs, 2) containing the - sum, count of the labels for each category. 
+ Returns a 3D array of shape (n_categories, n_outputs, 2) + containing the sum, count of the labels for each category. """ task = get_legate_runtime().create_auto_task( user_context, user_lib.cffi.TARGET_ENCODER_MEAN diff --git a/legateboost/legateboost.py b/legateboost/legateboost.py index 80c8beb2..6f283302 100644 --- a/legateboost/legateboost.py +++ b/legateboost/legateboost.py @@ -313,8 +313,8 @@ def update( eval_result: EvalResult = {}, ) -> Self: """Update a gradient boosting model from the training set (X, y). This - method does not add any new models to the ensemble, only updates existing - models to fit the new data. + method does not add any new models to the ensemble, only updates + existing models to fit the new data. Parameters ---------- @@ -476,8 +476,8 @@ def __iter__(self) -> Any: return iter(self.models_) def __mul__(self, scalar: Any) -> Self: - """Gradient boosted models are linear in the predictions before the non- - linear link function is applied. This means that the model can be + """Gradient boosted models are linear in the predictions before the + non- linear link function is applied. This means that the model can be multiplied by a scalar, which subsequently scales all raw output predictions. This is useful for ensembling models. @@ -550,8 +550,8 @@ def global_attributions( n_samples: int = 5, check_efficiency: bool = False, ) -> Tuple[cn.array, cn.array]: - r"""Compute global feature attributions for the model. Global attributions - show the effect of a feature on a model's loss function. + r"""Compute global feature attributions for the model. Global + attributions show the effect of a feature on a model's loss function. We use a Shapley value approach to compute the attributions: :math:`Sh_i(v)=\frac{1}{|N|!} \sum_{\sigma \in \mathfrak{S}_d} \big[ v([\sigma]_{i-1} \cup\{i\}) - v([\sigma]_{i-1}) \big],` @@ -612,10 +612,11 @@ def local_attributions( n_samples: int = 5, check_efficiency: bool = False, ) -> Tuple[cn.array, cn.array]: - r"""Local feature attributions for model predictions. Shows the effect of a - feature on each output prediction. See the definition of Shapley values in - :func:`~legateboost.BaseModel.global_attributions`, where the :math:`v` - function is here the model prediction instead of the loss function. + r"""Local feature attributions for model predictions. Shows the effect + of a feature on each output prediction. See the definition of Shapley + values in :func:`~legateboost.BaseModel.global_attributions`, where the + :math:`v` function is here the model prediction instead of the loss + function. Parameters ---------- @@ -749,8 +750,8 @@ def partial_fit( eval_set: List[Tuple[cn.ndarray, ...]] = [], eval_result: EvalResult = {}, ) -> LBBase: - """This method is used for incremental (online) training of the model. An - additional `n_estimators` models will be added to the ensemble. + """This method is used for incremental (online) training of the model. + An additional `n_estimators` models will be added to the ensemble. Parameters ---------- @@ -927,8 +928,8 @@ def partial_fit( eval_result: EvalResult = {}, ) -> LBBase: """This method is used for incremental fitting on a batch of samples. - Requires the classes to be provided up front, as they may not be inferred - from the first batch. + Requires the classes to be provided up front, as they may not be + inferred from the first batch. 
Parameters ---------- @@ -1032,8 +1033,8 @@ def fit( return self def predict_raw(self, X: cn.ndarray) -> cn.ndarray: - """Predict pre-transformed values for samples in X. E.g. before applying a - sigmoid function. + """Predict pre-transformed values for samples in X. E.g. before + applying a sigmoid function. Parameters ---------- diff --git a/legateboost/metrics.py b/legateboost/metrics.py index 6f4065b5..68e6a028 100644 --- a/legateboost/metrics.py +++ b/legateboost/metrics.py @@ -145,8 +145,8 @@ def name(self) -> str: class GammaLLMetric(BaseMetric): - """The mean negative log likelihood of the labels, given parameters predicted - by the model.""" + """The mean negative log likelihood of the labels, given parameters + predicted by the model.""" @override def metric(self, y: cn.ndarray, pred: cn.ndarray, w: cn.ndarray) -> cn.ndarray: @@ -253,8 +253,8 @@ def name(self) -> str: class LogLossMetric(BaseMetric): - """Class for computing the logarithmic loss (logloss) metric between the true - labels and predicted labels. + """Class for computing the logarithmic loss (logloss) metric between the + true labels and predicted labels. For binary classification: diff --git a/legateboost/models/base_model.py b/legateboost/models/base_model.py index a1e88011..39987f88 100644 --- a/legateboost/models/base_model.py +++ b/legateboost/models/base_model.py @@ -11,9 +11,9 @@ class BaseModel(PickleCupynumericMixin, ABC): """Base class for all models in LegateBoost. - Defines the interface for fitting, updating, and predicting a model, as well - as string representation and equality comparison. Implement these methods to - create a custom model. + Defines the interface for fitting, updating, and predicting a model, + as well as string representation and equality comparison. Implement + these methods to create a custom model. """ def set_random_state(self, random_state: np.random.RandomState) -> "BaseModel": @@ -27,7 +27,8 @@ def fit( g: cn.ndarray, h: cn.ndarray, ) -> "BaseModel": - """Fit the model to a second order Taylor expansion of the loss function. + """Fit the model to a second order Taylor expansion of the loss + function. Parameters ---------- diff --git a/legateboost/models/krr.py b/legateboost/models/krr.py index 31af8d5a..ef0de7ef 100644 --- a/legateboost/models/krr.py +++ b/legateboost/models/krr.py @@ -35,9 +35,9 @@ def rbf(x: cn.ndarray, sigma: float) -> cn.ndarray: class KRR(BaseModel): - """Kernel Ridge Regression model using the Nyström approximation. The accuracy - of the approximation is governed by the parameter `n_components` <= `n`. - Effectively, `n_components` rows will be randomly sampled (without + """Kernel Ridge Regression model using the Nyström approximation. The + accuracy of the approximation is governed by the parameter `n_components` + <= `n`. Effectively, `n_components` rows will be randomly sampled (without replacement) from X in each boosting iteration. The kernel is fixed to be the RBF kernel: diff --git a/legateboost/models/linear.py b/legateboost/models/linear.py index ec34594e..f8757772 100644 --- a/legateboost/models/linear.py +++ b/legateboost/models/linear.py @@ -9,11 +9,11 @@ class Linear(BaseModel): - """Generalised linear model. Boosting linear models is equivalent to fitting a - single linear model where each boosting iteration is a newton step. Note that - the l2 penalty is applied to the weights of each model, as opposed to the sum - of all models. This can lead to different results when compared to fitting a - linear model with sklearn. 
+ """Generalised linear model. Boosting linear models is equivalent to + fitting a single linear model where each boosting iteration is a newton + step. Note that the l2 penalty is applied to the weights of each model, as + opposed to the sum of all models. This can lead to different results when + compared to fitting a linear model with sklearn. It is recommended to normalize the data before fitting. This ensures regularisation is evenly applied to all features and prevents numerical issues. diff --git a/legateboost/objectives.py b/legateboost/objectives.py index 52865bb3..eafa0b98 100644 --- a/legateboost/objectives.py +++ b/legateboost/objectives.py @@ -100,12 +100,13 @@ def initialise_prediction( class ClassificationObjective(BaseObjective): - """Extension of BaseObjective for classification problems, use can optionaly - define a method of extracting a class output from probabilities.""" + """Extension of BaseObjective for classification problems, use can + optionaly define a method of extracting a class output from + probabilities.""" def output_class(self, pred: cn.ndarray) -> cn.ndarray: - """Defined how to output class labels from transfored output. This may be - as simple as argmax over probabilities. + """Defined how to output class labels from transfored output. This may + be as simple as argmax over probabilities. Args: pred (cn.ndarray): The transformed predictions. @@ -339,8 +340,8 @@ def initialise_prediction( class GammaObjective(FitInterceptRegMixIn, Forecast): - """Regression with the :math:`\\Gamma` distribution function using the shape - scale parameterization.""" + """Regression with the :math:`\\Gamma` distribution function using the + shape scale parameterization.""" @override def gradient(self, y: cn.ndarray, pred: cn.ndarray) -> GradPair: @@ -419,7 +420,8 @@ def var(self, param: cn.ndarray) -> cn.ndarray: class QuantileObjective(BaseObjective): - """Minimises the quantile loss, otherwise known as check loss or pinball loss. + """Minimises the quantile loss, otherwise known as check loss or pinball + loss. :math:`L(y_i, p_i) = \\frac{1}{k}\\sum_{j=1}^{k} (q_j - \\mathbb{1})(y_i - p_{i, j})` @@ -475,8 +477,8 @@ def initialise_prediction( class LogLossObjective(ClassificationObjective): - """The Log Loss objective function for binary and multi-class classification - problems. + """The Log Loss objective function for binary and multi-class + classification problems. This objective function computes the log loss between the predicted and true labels. @@ -565,8 +567,8 @@ def initialise_prediction( class ExponentialObjective(ClassificationObjective, FitInterceptRegMixIn): - """Exponential loss objective function for binary classification. Equivalent - to the AdaBoost multiclass exponential loss in [1]. + """Exponential loss objective function for binary classification. + Equivalent to the AdaBoost multiclass exponential loss in [1]. Defined as: diff --git a/legateboost/utils.py b/legateboost/utils.py index c38d62d0..fb7c3f95 100644 --- a/legateboost/utils.py +++ b/legateboost/utils.py @@ -156,8 +156,8 @@ def get_store(input: Any) -> LogicalStore: def solve_singular(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray: - """Solve a singular linear system Ax = b for x. The same as np.linalg.solve, - but if A is singular, then we use Algorithm 3.3 from: + """Solve a singular linear system Ax = b for x. The same as + np.linalg.solve, but if A is singular, then we use Algorithm 3.3 from: Nocedal, Jorge, and Stephen J. Wright, eds. Numerical optimization. 
New York, NY: Springer New York, 1999. @@ -202,7 +202,8 @@ def solve_singular(a: cn.ndarray, b: cn.ndarray) -> cn.ndarray: def sample_average( y: cn.ndarray, sample_weight: Optional[cn.ndarray] = None ) -> cn.ndarray: - """Compute weighted average on the first axis (usually the sample dimension). + """Compute weighted average on the first axis (usually the sample + dimension). Returns 0 if sum weight is zero or if the input is empty. """