Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/FUNDING.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
github: Gnpd
36 changes: 0 additions & 36 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,39 +114,3 @@ jobs:
run: |
poetry config repositories.testpypi https://test.pypi.org/legacy/
poetry publish --repository testpypi --username __token__ --password $TEST_PYPI_API_TOKEN

update-badge:
needs: publish-test
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Get TestPyPI version
run: |
VERSION=$(curl -s https://test.pypi.org/pypi/openmodels/json | jq -r .info.version) || exit 1
if [ -z "$VERSION" ] || [ "$VERSION" = "null" ]; then
echo "Error: Could not fetch version from TestPyPI." >&2
exit 1
fi
echo "VERSION=$VERSION" >> $GITHUB_ENV
- name: Write Shields JSON
run: |
cat > testpypi-badge.json <<EOF
{
"schemaVersion": 1,
"label": "TestPyPI",
"message": "${VERSION}",
"color": "orange"
}
EOF

- name: Commit badge JSON
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add testpypi-badge.json
git commit -m "Update TestPyPI badge to ${VERSION} [skip ci]" || echo "No changes"
git push https://x-access-token:${GITHUB_TOKEN}@github.com/${GITHUB_REPOSITORY}.git HEAD:main
2 changes: 1 addition & 1 deletion openmodels/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
Currently, it includes a serializer for scikit-learn models.
"""

from .sklearn_serializer import SklearnSerializer
from .sklearn.sklearn_serializer import SklearnSerializer

__all__ = ["SklearnSerializer"]
Empty file.
73 changes: 73 additions & 0 deletions openmodels/serializers/sklearn/_custom_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import inspect
import warnings
from typing import Any, Callable, Dict, List, Tuple, Type, Union


def is_valid_estimator(name: str, cls: Any) -> bool:
"""Check whether (name, cls) represents a valid sklearn estimator."""

if not isinstance(name, str):
return False
if not inspect.isclass(cls):
return False
try:
from sklearn.base import BaseEstimator

return issubclass(cls, BaseEstimator)
except TypeError:
return False


def normalize_estimators(
estimators: Union[Callable[..., Any], List[Any], Tuple[Any, ...], Dict[str, Any]]
) -> List[Any]:
"""Normalize input into a flat list of estimators or (name, class) items."""
if not isinstance(estimators, (list, tuple, set)):
return [estimators]
return list(estimators)


def load_custom_estimators(
custom_estimators: Union[
Callable[..., Any], List[Any], Tuple[Any, ...], Dict[str, Any]
],
all_estimators: Dict[str, Type],
) -> Dict[str, Type]:
"""Convert user-provided estimators into a dictionary of valid ones."""
extra = {}
for est in normalize_estimators(custom_estimators):
try:
items = est() if callable(est) else est
except Exception:
warnings.warn("Failed to call custom_estimator(); skipping.", UserWarning)
continue

if items is None:
continue

if isinstance(items, dict):
iterator = items.items()
else:
iterator = items

for item in iterator:
try:
name, cls = item
except Exception:
warnings.warn(
"Unexpected custom_estimator format; skipping.", UserWarning
)
continue

if not is_valid_estimator(name, cls):
continue

if name in all_estimators and all_estimators[name] is not cls:
warnings.warn(
f"Estimator '{name}' conflicts with built-in one; preferring custom version.",
UserWarning,
)

extra[name] = cls

return extra
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import numpy as np
import inspect

from ._custom_estimator import load_custom_estimators

import sklearn
from sklearn.calibration import _CalibratedClassifier, _SigmoidCalibration
from sklearn.cluster._birch import _CFNode
Expand Down Expand Up @@ -72,15 +74,15 @@
ALL_ESTIMATORS = {
name: cls for name, cls in all_estimators() if issubclass(cls, BaseEstimator)
}
# add _BinMapper to ALL_ESTIMATORS
# add extra private estimators to ALL_ESTIMATORS
ALL_ESTIMATORS["_BinMapper"] = _BinMapper
ALL_ESTIMATORS["_SigmoidCalibration"] = _SigmoidCalibration
ALL_ESTIMATORS["_BinaryGaussianProcessClassifierLaplace"] = (
_BinaryGaussianProcessClassifierLaplace
)
ALL_ESTIMATORS["_ConstantPredictor"] = _ConstantPredictor

TESTED_VERSIONS = ["1.6.1", "1.7.1"]
TESTED_VERSIONS = ["1.6.1", "1.7.2"]

NOT_SUPPORTED_ESTIMATORS: list[str] = [
# Regressors: all regressors work!! Hurray!
Expand Down Expand Up @@ -230,14 +232,63 @@ class SklearnSerializer(
The serializer supports a wide range of scikit-learn estimators and handles
the conversion of numpy arrays and other non-JSON-serializable types.

Attributes
Parameters
----------
SUPPORTED_ESTIMATORS : Dict[str, Type[BaseEstimator]]
A dictionary of supported scikit-learn estimator classes.
SUPPORTED_TYPES : List[Type]
A list of supported types for serialization.
custom_estimators : callable, list, tuple, or dict, optional
Optional collection of third-party or custom estimator classes to support during
serialization and deserialization. This can be:

- A callable returning an iterable or dict of (name, class) pairs (e.g., a function like ``all_estimators``).
- A list or tuple of (name, class) pairs.
- A dict mapping estimator names to their classes.

These estimators are merged into the serializer's internal registry for this instance only,
allowing support for custom or external estimators without affecting the global registry.

See Also
--------
scikit-learn developer guide:
https://scikit-learn.org/stable/developers/develop.html

sklearn.utils.discovery.all_estimators:
https://scikit-learn.org/stable/modules/generated/sklearn.utils.discovery.all_estimators.html

skltemplate.utils.discovery.all_estimators (project template):
https://contrib.scikit-learn.org/project-template/generated/skltemplate.utils.discovery.all_estimators.html

Developer Notes
--------------
For third-party packages compatible with scikit-learn, it is recommended to implement
an ``all_estimators()`` utility following the scikit-learn API and template above.
This enables automatic discovery and integration of custom estimators for serialization.

If you are maintaining a scikit-learn compatible package, let us know!
We are happy to extend our testing to include your estimators, ensuring everything works
smoothly and that we cover any unique types or patterns used in your library.

To request official support for your package, please open an issue at:
https://github.com/SF-Tec/openmodels/issues

"""

def __init__(
self,
custom_estimators: Optional[
Union[
Callable[..., Any],
List[Any],
Tuple[Any, ...],
Dict[str, Type[BaseEstimator]],
]
] = None,
):
extra = (
load_custom_estimators(custom_estimators, ALL_ESTIMATORS)
if custom_estimators
else {}
)
self._all_estimators: Dict[str, Type] = {**ALL_ESTIMATORS, **extra}

# --- Helpers ---
def _check_version(self, stored_version: Optional[str]) -> None:
"""
Expand Down Expand Up @@ -274,10 +325,15 @@ def all_estimators(
"""
Get all scikit-learn supported estimators.

Parameters
----------
type_filter : str, optional
If provided, filter estimators by type (e.g., 'classifier', 'regressor').

Returns
-------
Dict[str, BaseEstimator]
A dictionary of all scikit-learn supported estimators.
list of tuple
List of (name, class) pairs for supported estimators.
"""

return [
Expand Down Expand Up @@ -410,7 +466,7 @@ def _get_deserializer_handlers(self):
]
# Estimators
estimator_handlers = [
(est_name, self.deserialize) for est_name in ALL_ESTIMATORS.keys()
(est_name, self.deserialize) for est_name in self._all_estimators.keys()
]

kernel_handlers = [
Expand Down Expand Up @@ -805,7 +861,7 @@ def serialize(self, model: BaseEstimator) -> Dict[str, Any]:
Raises
------
SerializationError
If the model has not been fitted or if there's an error during serialization.
If there's an error during serialization.

Examples
--------
Expand All @@ -826,7 +882,7 @@ def serialize(self, model: BaseEstimator) -> Dict[str, Any]:
"params": self.convert_to_serializable(params),
"param_types": param_types,
"param_dtypes": param_dtypes,
"producer_version": getattr(model, "_sklearn_version", None),
"producer_version": sklearn.__version__,
"producer_name": model.__module__.split(".")[0],
"domain": "sklearn",
}
Expand Down Expand Up @@ -897,7 +953,7 @@ def deserialize(self, data: Dict[str, Any]) -> BaseEstimator:
params[key] = tuple(value)

# Get valid constructor arguments for the estimator
estimator_cls = ALL_ESTIMATORS[estimator_class]
estimator_cls = self._all_estimators[estimator_class]
valid_args = list(inspect.signature(estimator_cls.__init__).parameters.keys())
# Remove 'self' if present
valid_args = [arg for arg in valid_args if arg != "self"]
Expand Down
2 changes: 1 addition & 1 deletion openmodels/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def run_test_label_binarizer(

def test_multilabelbinarizer_minimal():
from sklearn.preprocessing import MultiLabelBinarizer
from openmodels.serializers.sklearn_serializer import SklearnSerializer
from openmodels.serializers.sklearn.sklearn_serializer import SklearnSerializer
from openmodels import SerializationManager
import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openmodels"
version = "0.1.0-alpha.17"
version = "0.1.0-alpha.18"
description = "Export scikit-learn model files to JSON for sharing or deploying predictive models with peace of mind."
authors = [
"Alejandro Gutierrez <agutierrez@sftec.es>, Pau Cabaneros <pau.cabaneros@gmail.com>, Raúl Marín <hi@raulmarin.dev>, Ruben Parrilla <rparrilla@sftec.es>",
Expand Down
2 changes: 1 addition & 1 deletion test/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.utils.discovery import all_estimators
from sklearn.datasets import make_classification
from openmodels.test_helpers import run_test_model
from openmodels.serializers.sklearn_serializer import NOT_SUPPORTED_ESTIMATORS
from openmodels.serializers.sklearn.sklearn_serializer import NOT_SUPPORTED_ESTIMATORS

# Get all classifier estimators, filtering out not supported classifiers
CLASSIFIERS = [cls for name, cls in all_estimators(type_filter="classifier")
Expand Down
2 changes: 1 addition & 1 deletion test/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.datasets import make_blobs
from sklearn.feature_extraction import FeatureHasher
from openmodels.test_helpers import run_test_model
from openmodels.serializers.sklearn_serializer import NOT_SUPPORTED_ESTIMATORS
from openmodels.serializers.sklearn.sklearn_serializer import NOT_SUPPORTED_ESTIMATORS

# Get all cluster estimators, filtering out not supported clusters
CLUSTERS = [
Expand Down
Loading