@RAMitchell
Contributor

Implements #225

This PR combines the model initialisation term and the individual ONNX models into a single serialised estimator.

I will likely only implement the predict_raw method here and leave predict_proba for another PR, as it will require e.g. softmax transforms.
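
For context, here is a minimal sketch of the kind of softmax transform that predict_proba would need on top of raw predictions. It is plain NumPy for illustration only, not the ONNX graph from this PR; the function name `softmax` and the assumed input shape (n_samples, n_classes) are assumptions, not part of the PR.

```python
import numpy as np


def softmax(raw_scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax over an (n_samples, n_classes) array of raw scores.

    Hypothetical helper for illustration; not taken from the PR.
    """
    # Subtract the row maximum first so np.exp cannot overflow.
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    # Normalise so every row sums to one.
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


raw = np.array([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]])
proba = softmax(raw)
```

Each row of `proba` sums to one, and a row of equal raw scores maps to a uniform distribution over the classes.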

@RAMitchell RAMitchell requested a review from Copilot March 31, 2025 14:41

Copilot AI left a comment

Pull Request Overview

This PR implements ONNX model serialization (phase 2) by combining model initialization with individual ONNX models into a serialized estimator. Key changes include updating dependencies to include onnxruntime, adding new test functions to verify ONNX predictions, and modifying each model’s to_onnx method (and related functions) to accept an explicit data type parameter and use consistent input/output naming.

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| pyproject.toml, dependencies.yaml, conda YAML | Updated dependencies to include onnxruntime>=1.21 |
| legateboost/test/test_onnx.py | Added new functions for ONNX predictions and updated test naming |
| legateboost/models/tree.py, nn.py, linear.py, krr.py | Modified to_onnx methods to accept X_dtype and standardized input/output names |
| legateboost/legateboost.py | Introduced _make_onnx_init and updated to_onnx to merge ONNX models |

Contributor

@seberg seberg left a comment

The approach looks good to me; commenting since I suspect the classification needs a bit of work. This is tricky to test!

I was thinking about the predict_function= argument, but it seems good to me. (predict_raw seems to duplicate "predict" a bit.)

I should look closer at some of the ONNX code probably.

assert onnx_pred.dtype == pred.dtype
assert pred.shape == onnx_pred.shape
number_wrong = np.sum(
    np.abs(pred - onnx_pred) > (1e-2 if X.dtype == np.float32 else 1e-5)

Are predictions always similarly sized (i.e. 0-1)? Just curious whether it would make sense to allow a relative deviation.
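
To illustrate the question, here is a rough sketch of how a relative deviation could be combined with the absolute one via `np.isclose`, which treats two values as close when |a - b| <= atol + rtol * |b|. The helper name `count_mismatches` and the sample values are hypothetical and not taken from the test in this PR.

```python
import numpy as np


def count_mismatches(pred, onnx_pred, rtol, atol):
    """Count elements where |onnx_pred - pred| > atol + rtol * |pred|.

    Hypothetical helper for illustration; not the PR's test code.
    """
    return int(np.sum(~np.isclose(onnx_pred, pred, rtol=rtol, atol=atol)))


# Made-up predictions spanning several orders of magnitude.
pred = np.array([0.5, 100.0, 1e4], dtype=np.float32)
onnx_pred = pred + np.array([0.001, 0.05, 5.0], dtype=np.float32)

# Absolute-only check, as in the quoted test for float32 inputs.
abs_only = count_mismatches(pred, onnx_pred, rtol=0.0, atol=1e-2)

# Adding a relative term scales the allowed error with the prediction size.
with_rel = count_mismatches(pred, onnx_pred, rtol=1e-3, atol=1e-2)
```

With these made-up values the absolute-only check flags the two larger predictions, while the combined check accepts all three. Whether that matters here depends on whether raw predictions can actually leave the 0-1 range.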

@RAMitchell
Contributor Author

The cuPyNumeric test failures here require https://github.com/nv-legate/legate.internal/pull/2177, so we need to wait for nightlies to become available.
