Implement ONNX model serialisation phase 2 #227
base: main
Conversation
Pull Request Overview
This PR implements ONNX model serialization (phase 2) by combining model initialization with individual ONNX models into a serialized estimator. Key changes include updating dependencies to include onnxruntime, adding new test functions to verify ONNX predictions, and modifying each model’s to_onnx method (and related functions) to accept an explicit data type parameter and use consistent input/output naming.
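As a rough illustration of the intended round trip, a serialised estimator could be exercised as below. This is a sketch only: `LBRegressor` is legateboost's regressor class, but the `to_onnx(X_dtype=...)` signature, its `ModelProto` return type, and the output shape matching `predict` are assumptions taken from the PR description rather than confirmed API.

```python
import numpy as np
import onnxruntime as ort

import legateboost as lb

# Fit a small model on toy data.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5)).astype(np.float32)
y = X.sum(axis=1)
model = lb.LBRegressor(n_estimators=5).fit(X, y)

# Serialise to a single ONNX graph and run it with onnxruntime.
# to_onnx(X_dtype=...) and the ModelProto return type are assumed from the PR text.
onnx_model = model.to_onnx(X_dtype=X.dtype)
session = ort.InferenceSession(onnx_model.SerializeToString())
input_name = session.get_inputs()[0].name  # avoids guessing the standardized input name
onnx_pred = session.run(None, {input_name: X})[0]

np.testing.assert_allclose(onnx_pred, model.predict(X), rtol=1e-2)
```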
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| pyproject.toml, dependencies.yaml, conda YAML | Updated dependencies to include onnxruntime>=1.21 |
| legateboost/test/test_onnx.py | Added new functions for ONNX predictions and updated test naming |
| legateboost/models/tree.py, nn.py, linear.py, krr.py | Modified to_onnx methods to accept X_dtype and standardized input/output names |
| legateboost/legateboost.py | Introduced _make_onnx_init and updated to_onnx to merge ONNX models (see the sketch after this table) |
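The internals of `_make_onnx_init` are not shown in this page, so the following is only a hypothetical sketch of how an initialisation term can be fused with a per-model graph using `onnx.compose`; all graph and tensor names (`raw`, `raw_in`, `pred`, the toy `MatMul` model) are invented for illustration.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx.compose import merge_models

# Toy stand-in for one model's graph: raw = X @ w.
w = np.ones((5, 1), dtype=np.float32)
model = helper.make_model(helper.make_graph(
    [helper.make_node("MatMul", ["X", "w"], ["raw"])],
    "model",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [None, 5])],
    [helper.make_tensor_value_info("raw", TensorProto.FLOAT, [None, 1])],
    initializer=[numpy_helper.from_array(w, name="w")],
))

# Toy stand-in for the initialisation term: pred = raw + init.
init = np.array([0.5], dtype=np.float32)
init_model = helper.make_model(helper.make_graph(
    [helper.make_node("Add", ["raw_in", "init"], ["pred"])],
    "init",
    [helper.make_tensor_value_info("raw_in", TensorProto.FLOAT, [None, 1])],
    [helper.make_tensor_value_info("pred", TensorProto.FLOAT, [None, 1])],
    initializer=[numpy_helper.from_array(init, name="init")],
))

# Wire the model's output into the init graph's input by name.
merged = merge_models(model, init_model, io_map=[("raw", "raw_in")])
onnx.checker.check_model(merged)
```

`merge_models` stitches graphs together by connecting named outputs to named inputs, which is presumably why the PR standardises input/output names across the models.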
seberg left a comment:
The approach looks good to me; I'm commenting because I suspect the classification path needs a bit of work. This is tricky to test!
I've been thinking about the predict_function= argument, but it seems good to me. (predict_raw seems to duplicate predict somewhat.)
I should probably look more closely at some of the ONNX code.
```python
assert onnx_pred.dtype == pred.dtype
assert pred.shape == onnx_pred.shape
number_wrong = np.sum(
    np.abs(pred - onnx_pred) > (1e-2 if X.dtype == np.float32 else 1e-5)
)
```
Predictions are always similarly sized (i.e. 0-1)? Just curious whether it would make sense to allow a relative deviation.
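For reference, a relative-tolerance variant of the check above could use np.isclose; the rtol/atol values here are illustrative, not taken from the PR.

```python
import numpy as np

# pred, onnx_pred, and X come from the test snippet above.
rtol = 1e-2 if X.dtype == np.float32 else 1e-5
number_wrong = np.sum(~np.isclose(onnx_pred, pred, rtol=rtol, atol=1e-8))
```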
cuPyNumeric test failures here require https://github.com/nv-legate/legate.internal/pull/2177 - we need to wait for nightlies to become available.
Implements #225
This PR combines the model initialisation term and the individual ONNX models into a single serialised estimator.
I will likely only implement the predict_raw method here and leave predict_proba for another PR, as it will require e.g. softmax transforms.
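For context on why predict_proba needs the extra work: probabilities require a normalising transform on top of the raw scores. Below is a hypothetical sketch of appending an ONNX Softmax node to a raw-score graph; the "raw"/"proba" names and the [None, n_classes] output shape are invented for illustration, not taken from the PR.

```python
import onnx
from onnx import TensorProto, helper


def append_softmax(model: onnx.ModelProto, n_classes: int) -> onnx.ModelProto:
    """Replace the raw-score output with softmax probabilities (sketch)."""
    graph = model.graph
    # Assumes the existing graph exposes an output named "raw".
    graph.node.append(helper.make_node("Softmax", ["raw"], ["proba"], axis=-1))
    del graph.output[:]
    graph.output.append(
        helper.make_tensor_value_info("proba", TensorProto.FLOAT, [None, n_classes])
    )
    return model
```

A binary classifier would presumably use a Sigmoid node instead, which is part of why this is being deferred to a separate PR.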