From 2b6f126b1bdee97a1e929d2a02cef219ad4adf05 Mon Sep 17 00:00:00 2001 From: Snomaan6846 Date: Fri, 23 Jan 2026 15:07:56 +0530 Subject: [PATCH] fix: Handle bind mounts and delayed model repository availability in KServe deployments Signed-off-by: Snomaan6846 rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED --- docs-gb/user-guide/deployment/kserve.md | 46 +++ docs/user-guide/deployment/kserve.md | 46 +++ mlserver/model.py | 42 ++- mlserver/repository/factory.py | 21 +- mlserver/repository/repository.py | 50 ++- mlserver/settings.py | 6 + mlserver/utils.py | 55 +++ .../mlserver_alibi_detect/runtime.py | 12 +- .../mlserver_alibi_explain/runtime.py | 12 +- .../catboost/mlserver_catboost/catboost.py | 13 +- .../mlserver_huggingface/common.py | 29 +- .../mlserver_huggingface/runtime.py | 30 +- .../lightgbm/mlserver_lightgbm/lightgbm.py | 13 +- runtimes/mlflow/mlserver_mlflow/runtime.py | 12 +- runtimes/mllib/mlserver_mllib/mllib.py | 12 +- runtimes/sklearn/mlserver_sklearn/sklearn.py | 12 +- runtimes/xgboost/mlserver_xgboost/xgboost.py | 13 +- tests/repository/test_repository.py | 91 +++++ tests/test_model_cleanup.py | 283 ++++++++++++++++ tests/test_utils.py | 313 +++++++++++++++++- 20 files changed, 1064 insertions(+), 47 deletions(-) create mode 100644 tests/test_model_cleanup.py diff --git a/docs-gb/user-guide/deployment/kserve.md b/docs-gb/user-guide/deployment/kserve.md index ad46f795f..92fa4d775 100644 --- a/docs-gb/user-guide/deployment/kserve.md +++ b/docs-gb/user-guide/deployment/kserve.md @@ -130,3 +130,49 @@ it directly through `kubectl`, by running: ```bash kubectl apply -f my-inferenceservice-manifest.yaml ``` + +## Advanced Configuration + +### Model Repository Availability + +When deploying with KServe, models may be mounted from external storage (e.g., S3, PVC, or OCI images). +In some scenarios, there may be a delay between when MLServer starts and when the model repository becomes available. 
+MLServer provides configuration options to handle such scenarios gracefully: + +| Setting | Environment Variable | Default | Description | +|---------|---------------------|---------|-------------| +| `model_repository_retries` | `MLSERVER_MODEL_REPOSITORY_RETRIES` | `10` | Number of retries to wait for model repository to become available | +| `model_repository_wait_interval` | `MLSERVER_MODEL_REPOSITORY_WAIT_INTERVAL` | `1.0` | Wait interval (in seconds) between retries | + +These settings can be configured via environment variables in your `InferenceService` manifest: + +```yaml +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: my-model +spec: + predictor: + sklearn: + protocolVersion: v2 + storageUri: gs://seldon-models/sklearn/iris + env: + - name: MLSERVER_MODEL_REPOSITORY_RETRIES + value: "20" + - name: MLSERVER_MODEL_REPOSITORY_WAIT_INTERVAL + value: "2.0" +``` + +Or via a `settings.json` file in your model repository: + +```json +{ + "model_repository_retries": 20, + "model_repository_wait_interval": 2.0 +} +``` + +This is particularly useful when working with: +- **OCI model images** where the model sidecar may take time to mount files +- **Network storage** where connectivity or initialization delays may occur +- **Large models** where the download or extraction process takes time diff --git a/docs/user-guide/deployment/kserve.md b/docs/user-guide/deployment/kserve.md index fe7066011..db555457d 100644 --- a/docs/user-guide/deployment/kserve.md +++ b/docs/user-guide/deployment/kserve.md @@ -147,3 +147,49 @@ it directly through `kubectl`, by running: ```bash kubectl apply -f my-inferenceservice-manifest.yaml ``` + +## Advanced Configuration + +### Model Repository Availability + +When deploying with KServe, models may be mounted from external storage (e.g., S3, PVC, or OCI images). +In some scenarios, there may be a delay between when MLServer starts and when the model repository becomes available. 
+MLServer provides configuration options to handle such scenarios gracefully:
+
+| Setting | Environment Variable | Default | Description |
+|---------|---------------------|---------|-------------|
+| `model_repository_retries` | `MLSERVER_MODEL_REPOSITORY_RETRIES` | `10` | Number of retries to wait for model repository to become available |
+| `model_repository_wait_interval` | `MLSERVER_MODEL_REPOSITORY_WAIT_INTERVAL` | `1.0` | Wait interval (in seconds) between retries |
+
+These settings can be configured via environment variables in your `InferenceService` manifest:
+
+```yaml
+apiVersion: serving.kserve.io/v1beta1
+kind: InferenceService
+metadata:
+  name: my-model
+spec:
+  predictor:
+    sklearn:
+      protocolVersion: v2
+      storageUri: gs://seldon-models/sklearn/iris
+      env:
+        - name: MLSERVER_MODEL_REPOSITORY_RETRIES
+          value: "20"
+        - name: MLSERVER_MODEL_REPOSITORY_WAIT_INTERVAL
+          value: "2.0"
+```
+
+Or via a `settings.json` file in your model repository:
+
+```json
+{
+  "model_repository_retries": 20,
+  "model_repository_wait_interval": 2.0
+}
+```
+
+This is particularly useful when working with:
+- **OCI model images** where the model sidecar may take time to mount files
+- **Network storage** where connectivity or initialization delays may occur
+- **Large models** where the download or extraction process takes time
diff --git a/mlserver/model.py b/mlserver/model.py
index 8296bc19e..573cbe7f9 100644
--- a/mlserver/model.py
+++ b/mlserver/model.py
@@ -1,3 +1,6 @@
+import os
+import logging
+
 from typing import Any, Dict, Optional, List, AsyncIterator
 
 from .codecs import (
@@ -25,6 +28,8 @@
     Parameters,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def _generate_metadata_index(
     metadata_tensors: Optional[List[MetadataTensor]],
@@ -52,6 +57,9 @@ def __init__(self, settings: ModelSettings):
 
         self._inputs_index = _generate_metadata_index(self._settings.inputs)
         self._outputs_index = _generate_metadata_index(self._settings.outputs)
+
+        # Track transient model files for automatic cleanup on unload
+        self._transient_model_files: List[str] = []
 
         self.ready = False
 
@@ -91,6 +99,21 @@ async def predict_stream(
         logic.**
         """
         yield await self.predict((await payloads.__anext__()))
+
+    def register_transient_file(self, file_path: str) -> None:
+        """
+        Register a transient model file for automatic cleanup when the model is unloaded.
+
+        Transient files are created when model artifacts need to be copied from
+        incompatible filesystem mounts (e.g., bind mounts, proc paths) to local
+        storage for runtime compatibility. These files are automatically removed
+        when the model is unloaded.
+
+        Args:
+            file_path: Path to the transient model file to track for cleanup
+        """
+        if file_path and file_path not in self._transient_model_files:
+            self._transient_model_files.append(file_path)
 
     async def unload(self) -> bool:
         """
@@ -100,10 +123,27 @@ async def unload(self) -> bool:
         :doc:`parallel inference `) is enabled).
 
         A return value of ``True`` will mean the model is now unloaded.
+
+        This base implementation automatically cleans up any transient model files
+        registered via ``register_transient_file()``.
 
         **This method can be overriden to implement your custom unload
-        logic.**
+        logic. If you override this method, call super().unload() to ensure
+        transient files are cleaned up.**
         """
+        # Clean up transient model files
+        for transient_file in self._transient_model_files:
+            try:
+                if os.path.exists(transient_file):
+                    os.remove(transient_file)
+                    logger.debug(f"Cleaned up transient model file: {transient_file}")
+            except Exception as e:
+                # Log but don't fail unload if cleanup fails
+                logger.warning(
+                    f"Failed to cleanup transient file {transient_file}: {e}"
+                )
+
+        self._transient_model_files.clear()
         return True
 
     @property
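For custom runtimes, the intended flow introduced above is: resolve the URI, copy it locally if needed, register the copy, and defer cleanup to the base class. A minimal sketch against this patch's API — `my_framework_load` is a hypothetical stand-in for a real framework loader:

```python
from mlserver.model import MLModel
from mlserver.utils import get_model_uri, ensure_local_path


def my_framework_load(path: str) -> bytes:
    # Hypothetical stand-in for a framework-specific loader
    with open(path, "rb") as f:
        return f.read()


class MyRuntime(MLModel):
    async def load(self) -> bool:
        model_uri = await get_model_uri(self._settings)

        local_uri = ensure_local_path(model_uri)
        if local_uri != model_uri:
            # A temporary copy was made: track it so the base
            # unload() removes it automatically
            self.register_transient_file(local_uri)

        self._model = my_framework_load(local_uri)
        return True

    async def unload(self) -> bool:
        # Custom teardown would go here, followed by the base-class
        # cleanup of any registered transient files
        return await super().unload()
```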
diff --git a/mlserver/repository/factory.py b/mlserver/repository/factory.py
index 863dd94f5..a92bc0a9a 100644
--- a/mlserver/repository/factory.py
+++ b/mlserver/repository/factory.py
@@ -1,6 +1,7 @@
 from .repository import ModelRepository, SchemalessModelRepository
 from ..settings import Settings
 from pydantic import ImportString
+import inspect
 
 
 class ModelRepositoryFactory:
@@ -12,9 +13,21 @@ def resolve_model_repository(settings: Settings) -> ModelRepository:
         if settings.model_repository_implementation:
             model_repository_implementation = settings.model_repository_implementation
 
-        result = model_repository_implementation(
-            root=settings.model_repository_root,
-            **settings.model_repository_implementation_args,
-        )
+        # Check if the repository constructor accepts a 'settings' parameter
+        sig = inspect.signature(model_repository_implementation.__init__)
+        accepts_settings = 'settings' in sig.parameters
+
+        if accepts_settings:
+            result = model_repository_implementation(
+                root=settings.model_repository_root,
+                settings=settings,
+                **settings.model_repository_implementation_args,
+            )
+        else:
+            # Backward compatibility: don't pass settings if not accepted
+            result = model_repository_implementation(
+                root=settings.model_repository_root,
+                **settings.model_repository_implementation_args,
+            )
 
         return result
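Because the factory only inspects the constructor signature, a custom repository opts in simply by declaring the parameter. A rough sketch of both styles, assuming this patch's constructor signature (class names are illustrative):

```python
from typing import Optional

from mlserver.repository import SchemalessModelRepository
from mlserver.settings import Settings


class SettingsAwareRepository(SchemalessModelRepository):
    # Declares `settings`, so the factory passes the server Settings in
    def __init__(self, root: str, settings: Optional[Settings] = None):
        super().__init__(root, settings=settings)


class LegacyRepository(SchemalessModelRepository):
    # No `settings` parameter: the factory falls back to the old call
    def __init__(self, root: str):
        super().__init__(root)
```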
+ ) + + for attempt in range(self._retries): + time.sleep(self._wait_interval) + + if os.path.exists(abs_path): + logger.info(f"Model repository is now available") + time.sleep(0.5) # Brief delay for files to stabilize + return abs_path + + logger.warning(f"Model repository still not available after waiting") + return abs_path async def list(self) -> List[ModelSettings]: all_model_settings = [] # TODO: Use an async alternative for filesys ops if self._root: - abs_root = os.path.abspath(self._root) + # Wait for model repository to become available + abs_root = self._wait_for_repository(self._root) + pattern = os.path.join(abs_root, "**", DEFAULT_MODEL_SETTINGS_FILENAME) matches = glob.glob(pattern, recursive=True) + + if len(matches) == 0: + logger.warning(f"No model-settings.json found in {self._root}") for model_settings_path in matches: try: diff --git a/mlserver/settings.py b/mlserver/settings.py index 3c5501d5d..b9fc95182 100644 --- a/mlserver/settings.py +++ b/mlserver/settings.py @@ -190,6 +190,12 @@ class Settings(BaseSettings): model_repository_root: str = "." """Root of the model repository, where we will search for models.""" + model_repository_retries: int = 10 + """Number of retries to wait for model repository to become available.""" + + model_repository_wait_interval: float = 1.0 + """Wait interval (in seconds) between retries for model repository.""" + # Model Repository parameters are meant to be set directly by the MLServer runtime. model_repository_implementation_args: dict = {} """Extra parameters for model repository.""" diff --git a/mlserver/utils.py b/mlserver/utils.py index 34217aae5..d6e0a4d09 100644 --- a/mlserver/utils.py +++ b/mlserver/utils.py @@ -12,6 +12,8 @@ from typing import cast import warnings import urllib.parse +import shutil +import tempfile from asyncio import Task from typing import Any, Optional, TypeVar @@ -223,6 +225,59 @@ async def get_model_uri( raise InvalidModelURI(settings.name, full_model_path) +def ensure_local_path(file_path: str) -> str: + """ + Ensure a file path is on local filesystem accessible by model runtimes. + + Some runtimes may have issues loading from certain filesystem types + (e.g., special mounts, network filesystems). This function returns a + path that is guaranteed to be on local storage. + + If the file is already on accessible storage, returns the original path. + Otherwise, copies to temporary storage and returns the temp path. + + Note: Caller is responsible for cleanup of temporary files. + + Args: + file_path: Original file path + + Returns: + Path to file on local filesystem + """ + # Check if file is accessible + if not os.path.exists(file_path): + return file_path + + # Check if path contains indicators of special mounts + realpath = os.path.realpath(file_path) + normalized_input = os.path.normpath(file_path) + + # Detect problematic filesystem configurations: + # 1. Path contains /proc/ (proc-based mounts) + # 2. Path resolves to a different location (bind mounts, symlinks) + # 3. 
diff --git a/mlserver/utils.py b/mlserver/utils.py
index 34217aae5..d6e0a4d09 100644
--- a/mlserver/utils.py
+++ b/mlserver/utils.py
@@ -12,6 +12,8 @@
 from typing import cast
 import warnings
 import urllib.parse
+import shutil
+import tempfile
 
 from asyncio import Task
 from typing import Any, Optional, TypeVar
@@ -223,6 +225,59 @@ async def get_model_uri(
     raise InvalidModelURI(settings.name, full_model_path)
 
 
+def ensure_local_path(file_path: str) -> str:
+    """
+    Ensure a file path is on a local filesystem accessible by model runtimes.
+
+    Some runtimes may have issues loading from certain filesystem types
+    (e.g., special mounts, network filesystems). This function returns a
+    path that is guaranteed to be on local storage.
+
+    If the file is already on accessible storage, returns the original path.
+    Otherwise, copies the file to temporary storage and returns the temp path.
+
+    Note: The caller is responsible for cleaning up temporary files.
+
+    Args:
+        file_path: Original file path
+
+    Returns:
+        Path to the file on a local filesystem
+    """
+    # Check if the file is accessible
+    if not os.path.exists(file_path):
+        return file_path
+
+    # Check if the path contains indicators of special mounts
+    realpath = os.path.realpath(file_path)
+    normalized_input = os.path.normpath(file_path)
+
+    # Detect problematic filesystem configurations:
+    # 1. Path contains /proc/ (proc-based mounts)
+    # 2. Path resolves to a different location (bind mounts, symlinks)
+    # 3. Broken symlinks
+    needs_copy = (
+        '/proc/' in realpath
+        or realpath != normalized_input
+        or (os.path.islink(file_path) and not os.path.exists(realpath))
+    )
+
+    if not needs_copy:
+        return file_path
+
+    # Copy to a temp location
+    logger.info("Copying model file to local temporary storage for runtime compatibility")
+
+    filename = os.path.basename(file_path)
+    suffix = os.path.splitext(filename)[1]
+    temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, prefix='mlserver_')
+    os.close(temp_fd)
+
+    shutil.copy2(file_path, temp_path)
+    logger.debug(f"Model copied from {file_path} to {temp_path}")
+    return temp_path
+
+
 def to_absolute_path(model_settings: ModelSettings, uri: str) -> str:
     source = model_settings._source
     if source is None:
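Outside of an `MLModel` subclass (where `register_transient_file()` handles this), the caller owns any temporary copy. A minimal sketch of that contract:

```python
import os

from mlserver.utils import ensure_local_path


def read_model_bytes(original: str) -> bytes:
    """Copy-if-needed, read, then clean up any temporary copy."""
    local = ensure_local_path(original)
    try:
        with open(local, "rb") as f:
            return f.read()
    finally:
        # ensure_local_path leaves cleanup to the caller
        if local != original and os.path.exists(local):
            os.remove(local)
```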
diff --git a/runtimes/alibi-detect/mlserver_alibi_detect/runtime.py b/runtimes/alibi-detect/mlserver_alibi_detect/runtime.py
index d746e6ac5..e8bc73c2e 100644
--- a/runtimes/alibi-detect/mlserver_alibi_detect/runtime.py
+++ b/runtimes/alibi-detect/mlserver_alibi_detect/runtime.py
@@ -15,7 +15,7 @@
 from mlserver.settings import ModelSettings
 from mlserver.model import MLModel
 from mlserver.codecs import NumpyCodec, NumpyRequestCodec
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver.errors import MLServerError, InferenceError
 from mlserver.logging import logger
 
@@ -70,8 +70,16 @@ def __init__(self, settings: ModelSettings):
 
     async def load(self) -> bool:
         self._model_uri = await get_model_uri(self._settings)
+
+        # Ensure model is on local filesystem for compatibility
+        local_model_uri = ensure_local_path(self._model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != self._model_uri:
+            self.register_transient_file(local_model_uri)
+
         try:
-            self._model = load_detector(self._model_uri)
+            self._model = load_detector(local_model_uri)
             mlserver.register("seldon_model_drift", "Drift metrics")
             # Check whether an online drift detector (i.e. has a save_state method)
diff --git a/runtimes/alibi-explain/mlserver_alibi_explain/runtime.py b/runtimes/alibi-explain/mlserver_alibi_explain/runtime.py
index 9fc842eed..86ca1833f 100644
--- a/runtimes/alibi-explain/mlserver_alibi_explain/runtime.py
+++ b/runtimes/alibi-explain/mlserver_alibi_explain/runtime.py
@@ -24,7 +24,7 @@
     Parameters,
     ResponseOutput,
 )
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver_alibi_explain.alibi_dependency_reference import (
     get_mlmodel_class_as_str,
     get_alibi_class_as_str,
@@ -128,7 +128,15 @@ async def _load_from_uri(self, predictor: Any) -> Explainer:
         if model_parameters is None:
             raise ModelParametersMissing(self.name)
         absolute_uri = await get_model_uri(self.settings)
-        return await self._load_explainer(path=absolute_uri, predictor=predictor)
+
+        # Ensure model is on local filesystem for compatibility
+        local_uri = ensure_local_path(absolute_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_uri != absolute_uri:
+            self.register_transient_file(local_uri)
+
+        return await self._load_explainer(path=local_uri, predictor=predictor)
 
     async def _load_explainer(self, path: str, predictor: Any) -> Explainer:
         loop = asyncio.get_running_loop()
diff --git a/runtimes/catboost/mlserver_catboost/catboost.py b/runtimes/catboost/mlserver_catboost/catboost.py
index c8cb2e856..793df41c5 100644
--- a/runtimes/catboost/mlserver_catboost/catboost.py
+++ b/runtimes/catboost/mlserver_catboost/catboost.py
@@ -2,7 +2,7 @@
 
 from mlserver import types
 from mlserver.model import MLModel
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver.codecs import NumpyCodec, NumpyRequestCodec
 
 
@@ -18,9 +18,16 @@ async def load(self) -> bool:
         model_uri = await get_model_uri(
             self._settings, wellknown_filenames=WELLKNOWN_MODEL_FILENAMES
         )
-
+
+        # Ensure model is on local filesystem for compatibility
+        local_model_uri = ensure_local_path(model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != model_uri:
+            self.register_transient_file(local_model_uri)
+
         self._model = CatBoostClassifier()
-        self._model.load_model(model_uri)
+        self._model.load_model(local_model_uri)
 
         self.ready = True
         return self.ready
diff --git a/runtimes/huggingface/mlserver_huggingface/common.py b/runtimes/huggingface/mlserver_huggingface/common.py
index e53d7f6e2..a6a00eea6 100644
--- a/runtimes/huggingface/mlserver_huggingface/common.py
+++ b/runtimes/huggingface/mlserver_huggingface/common.py
@@ -5,6 +5,8 @@
 from functools import partial
 from mlserver.logging import logger
 from mlserver.settings import ModelSettings
+from mlserver.utils import ensure_local_path
+from typing import Optional, Tuple
 import torch
 import tensorflow as tf
 
@@ -23,15 +24,33 @@
 def load_pipeline_from_settings(
     hf_settings: HuggingFaceSettings, settings: ModelSettings
-) -> Pipeline:
+) -> Tuple[Pipeline, Optional[str]]:
+    """
+    Load a HuggingFace pipeline from settings.
+
+    Returns:
+        Tuple of (pipeline, transient_file_path), where transient_file_path
+        is the path to clean up if a local copy was created, or None if the
+        original path is used.
+ """ pipeline = _get_pipeline_class(hf_settings) batch_size = 1 if settings.max_batch_size: batch_size = settings.max_batch_size - model = hf_settings.pretrained_model - if not model: - model = settings.parameters.uri # type: ignore + original_model = hf_settings.pretrained_model + transient_file = None + + if not original_model: + original_model = settings.parameters.uri # type: ignore + # Ensure model path is on local filesystem for compatibility + if original_model and not original_model.startswith(("http://", "https://")): + local_model = ensure_local_path(original_model) + if local_model != original_model: + transient_file = local_model + original_model = local_model + + model = original_model tokenizer = hf_settings.pretrained_tokenizer if not tokenizer: tokenizer = hf_settings.pretrained_model @@ -73,7 +92,7 @@ def load_pipeline_from_settings( ) hf_pipeline._batch_size = 1 - return hf_pipeline + return hf_pipeline, transient_file def _get_pipeline_class(hf_settings: HuggingFaceSettings) -> _PipelineConstructor: diff --git a/runtimes/huggingface/mlserver_huggingface/runtime.py b/runtimes/huggingface/mlserver_huggingface/runtime.py index 2e1316ddb..3affe341a 100644 --- a/runtimes/huggingface/mlserver_huggingface/runtime.py +++ b/runtimes/huggingface/mlserver_huggingface/runtime.py @@ -33,7 +33,14 @@ async def load(self) -> bool: ) # Now we load the cached model which should not block asyncio - self._model = load_pipeline_from_settings(self.hf_settings, self.settings) + self._model, transient_file = load_pipeline_from_settings( + self.hf_settings, self.settings + ) + + # Register transient file for automatic cleanup on unload + if transient_file: + self.register_transient_file(transient_file) + self._merge_metadata() return True @@ -52,19 +59,16 @@ async def predict(self, payload: InferenceRequest) -> InferenceResponse: ) async def unload(self) -> bool: - # TODO: Free up Tensorflow's GPU memory + # Free up GPU memory if using Torch is_torch = self._model.framework == "pt" - if not is_torch: - return True - - uses_gpu = torch.cuda.is_available() and self._model.device != -1 - if not uses_gpu: - # Nothing to free - return True - - # Free up Torch's GPU memory - torch.cuda.empty_cache() - return True + if is_torch: + uses_gpu = torch.cuda.is_available() and self._model.device != -1 + if uses_gpu: + # Free up Torch's GPU memory + torch.cuda.empty_cache() + + # Call parent to cleanup transient files + return await super().unload() def _merge_metadata(self) -> None: meta = METADATA.get(self.hf_settings.task) diff --git a/runtimes/lightgbm/mlserver_lightgbm/lightgbm.py b/runtimes/lightgbm/mlserver_lightgbm/lightgbm.py index 420cb8184..8de98b63e 100644 --- a/runtimes/lightgbm/mlserver_lightgbm/lightgbm.py +++ b/runtimes/lightgbm/mlserver_lightgbm/lightgbm.py @@ -2,7 +2,7 @@ from mlserver import types from mlserver.model import MLModel -from mlserver.utils import get_model_uri +from mlserver.utils import get_model_uri, ensure_local_path from mlserver.codecs import NumpyCodec, NumpyRequestCodec @@ -18,8 +18,15 @@ async def load(self) -> bool: model_uri = await get_model_uri( self._settings, wellknown_filenames=WELLKNOWN_MODEL_FILENAMES ) - - self._model = lgb.Booster(model_file=model_uri) + + # Ensure model is on local filesystem for compatibility + local_model_uri = ensure_local_path(model_uri) + + # Register transient file for automatic cleanup on unload + if local_model_uri != model_uri: + self.register_transient_file(local_model_uri) + + self._model = lgb.Booster(model_file=local_model_uri) 
 
         return True
 
 
diff --git a/runtimes/mlflow/mlserver_mlflow/runtime.py b/runtimes/mlflow/mlserver_mlflow/runtime.py
index ffb322fe0..2bb3c773d 100644
--- a/runtimes/mlflow/mlserver_mlflow/runtime.py
+++ b/runtimes/mlflow/mlserver_mlflow/runtime.py
@@ -17,7 +17,7 @@
 
 from mlserver.types import InferenceRequest, InferenceResponse
 from mlserver.model import MLModel
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver.handlers import custom_handler
 from mlserver.errors import InferenceError
 from mlserver.settings import ModelParameters
@@ -155,7 +155,15 @@ async def invocations(
     async def load(self) -> bool:
         # TODO: Log info message
         model_uri = await get_model_uri(self._settings)
-        self._model = mlflow.pyfunc.load_model(model_uri)
+
+        # Ensure model is on local filesystem for compatibility
+        local_model_uri = ensure_local_path(model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != model_uri:
+            self.register_transient_file(local_model_uri)
+
+        self._model = mlflow.pyfunc.load_model(local_model_uri)
 
         self._input_schema = self._model.metadata.get_input_schema()
         self._signature = self._model.metadata.signature
diff --git a/runtimes/mllib/mlserver_mllib/mllib.py b/runtimes/mllib/mlserver_mllib/mllib.py
index 35287fd68..879cb9e70 100644
--- a/runtimes/mllib/mlserver_mllib/mllib.py
+++ b/runtimes/mllib/mlserver_mllib/mllib.py
@@ -1,5 +1,5 @@
 from mlserver import MLModel, types
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver.errors import InferenceError
 
 from pyspark import SparkContext, SparkConf
@@ -14,9 +14,17 @@ async def load(self) -> bool:
         sc = SparkContext(appName="MLlibModel", conf=conf)
 
         model_uri = await get_model_uri(self._settings)
+
+        # Ensure model is on local filesystem for compatibility
+        local_model_uri = ensure_local_path(model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != model_uri:
+            self.register_transient_file(local_model_uri)
+
         model_load = await get_mllib_load(self._settings)
 
-        self._model = model_load(sc, model_uri)
+        self._model = model_load(sc, local_model_uri)
 
         return True
 
diff --git a/runtimes/sklearn/mlserver_sklearn/sklearn.py b/runtimes/sklearn/mlserver_sklearn/sklearn.py
index e722eb926..17bc54221 100644
--- a/runtimes/sklearn/mlserver_sklearn/sklearn.py
+++ b/runtimes/sklearn/mlserver_sklearn/sklearn.py
@@ -12,7 +12,7 @@
     ResponseOutput,
     RequestOutput,
 )
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 
 PREDICT_FN_KEY = "predict_fn"
 PREDICT_OUTPUT = "predict"
@@ -34,7 +34,15 @@ async def load(self) -> bool:
         model_uri = await get_model_uri(
             self._settings, wellknown_filenames=WELLKNOWN_MODEL_FILENAMES
         )
-        self._model = joblib.load(model_uri)
+
+        # Ensure model is on local filesystem for compatibility
+        local_model_uri = ensure_local_path(model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != model_uri:
+            self.register_transient_file(local_model_uri)
+
+        self._model = joblib.load(local_model_uri)
 
         return True
 
diff --git a/runtimes/xgboost/mlserver_xgboost/xgboost.py b/runtimes/xgboost/mlserver_xgboost/xgboost.py
index 9e97fe132..e03edbfcf 100644
--- a/runtimes/xgboost/mlserver_xgboost/xgboost.py
+++ b/runtimes/xgboost/mlserver_xgboost/xgboost.py
@@ -5,7 +5,7 @@
 
 from mlserver.errors import InferenceError
 from mlserver.model import MLModel
-from mlserver.utils import get_model_uri
+from mlserver.utils import get_model_uri, ensure_local_path
 from mlserver.codecs import NumpyRequestCodec, NumpyCodec
 from mlserver.types import (
     InferenceRequest,
@@ -43,8 +43,15 @@ async def load(self) -> bool:
         model_uri = await get_model_uri(
             self._settings, wellknown_filenames=WELLKNOWN_MODEL_FILENAMES
         )
-
-        self._model = _load_sklearn_interface(model_uri)
+
+        # Ensure model is on local filesystem for XGBoost compatibility
+        local_model_uri = ensure_local_path(model_uri)
+
+        # Register transient file for automatic cleanup on unload
+        if local_model_uri != model_uri:
+            self.register_transient_file(local_model_uri)
+
+        self._model = _load_sklearn_interface(local_model_uri)
 
         return True
 
diff --git a/tests/repository/test_repository.py b/tests/repository/test_repository.py
index 6e8e9586a..0d3a74beb 100644
--- a/tests/repository/test_repository.py
+++ b/tests/repository/test_repository.py
@@ -1,5 +1,8 @@
 import os
 import pytest
+import time
+
+from unittest.mock import patch, MagicMock
 
 from mlserver.repository import (
     ModelRepository,
@@ -115,3 +118,91 @@ async def test_find(
 
     assert len(found_model_settings) == 1
     assert found_model_settings[0].name == sum_model_settings.name
+
+
+class TestRepositoryRetryLogic:
+    """Tests for model repository retry logic (delayed mount handling)"""
+
+    async def test_wait_for_repository_immediate_success(self, tmp_path):
+        """Test that immediate path existence doesn't wait"""
+        repo = SchemalessModelRepository(str(tmp_path), settings=None)
+
+        start_time = time.time()
+        result = repo._wait_for_repository(str(tmp_path))
+        elapsed = time.time() - start_time
+
+        assert result == str(tmp_path)
+        assert elapsed < 0.5  # Should be instant
+
+    async def test_wait_for_repository_timeout(self, tmp_path):
+        """Test that the repository gives up after max retries"""
+        nonexistent = tmp_path / "never_appears"
+        repo = SchemalessModelRepository(str(nonexistent), settings=None)
+
+        # Override retry settings
+        repo._retries = 3
+        repo._wait_interval = 0.1
+
+        with patch('os.path.exists', return_value=False):
+            start_time = time.time()
+            result = repo._wait_for_repository(str(nonexistent))
+            elapsed = time.time() - start_time
+
+        # Should have waited ~0.3s (3 retries * 0.1s)
+        assert 0.25 < elapsed < 0.5
+        # Should still return the path (caller handles non-existence)
+        assert result == str(nonexistent)
+
+    async def test_wait_for_repository_zero_retries(self, tmp_path):
+        """Test that with zero retries, no waiting occurs"""
+        nonexistent = tmp_path / "immediate_fail"
+        repo = SchemalessModelRepository(str(nonexistent), settings=None)
+
+        # Override retry settings
+        repo._retries = 0
+
+        with patch('os.path.exists', return_value=False):
+            start_time = time.time()
+            result = repo._wait_for_repository(str(nonexistent))
+            elapsed = time.time() - start_time
+
+        # Should be instant
+        assert elapsed < 0.1
+        assert result == str(nonexistent)
+
+    async def test_list_with_delayed_mount(self, tmp_path):
+        """Test that list() uses retry logic for delayed mounts"""
+        model_dir = tmp_path / "models"
+        repo = SchemalessModelRepository(str(model_dir), settings=None)
+
+        # Override retry settings
+        repo._retries = 5
+        repo._wait_interval = 0.1
+
+        # Simulate a delayed mount by creating the directory after a delay
+        call_count = 0
+        original_exists = os.path.exists
+
+        def mock_exists(path):
+            nonlocal call_count
+            call_count += 1
+
+            # Return False for the first 2 calls, then create and return True
+            if call_count <= 2:
+                return False
+
+            # Create the directory structure on the 3rd call
+            if not original_exists(path):
+                os.makedirs(path, exist_ok=True)
+                # Create a model-settings.json
+                settings_file = os.path.join(path, DEFAULT_MODEL_SETTINGS_FILENAME)
+                with open(settings_file, 'w') as f:
+                    f.write('{"name": "test-model", "implementation": "mlserver.MLModel"}')
+            return True
+
+        with patch('os.path.exists', side_effect=mock_exists):
+            settings_list = await repo.list()
+
+        # Should successfully find the model after waiting
+        assert len(settings_list) == 1
+        assert settings_list[0].name == "test-model"
diff --git a/tests/test_model_cleanup.py b/tests/test_model_cleanup.py
new file mode 100644
index 000000000..ef3048d98
--- /dev/null
+++ b/tests/test_model_cleanup.py
@@ -0,0 +1,283 @@
+"""
+Tests for model transient file cleanup functionality.
+"""
+import pytest
+import os
+import tempfile
+from unittest.mock import patch
+
+from mlserver.model import MLModel
+from mlserver.settings import ModelSettings
+from mlserver.types import InferenceRequest, InferenceResponse
+
+
+class SimpleTestModel(MLModel):
+    """Simple test model for cleanup testing"""
+
+    async def load(self) -> bool:
+        return True
+
+    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
+        return InferenceResponse(
+            model_name=self.name,
+            outputs=[]
+        )
+
+
+@pytest.fixture
+def model_settings():
+    return ModelSettings(name="test-model", implementation=SimpleTestModel)
+
+
+@pytest.fixture
+def test_model(model_settings):
+    return SimpleTestModel(model_settings)
+
+
+class TestModelTransientFileCleanup:
+    """Test suite for MLModel transient file cleanup"""
+
+    def test_register_transient_file(self, test_model):
+        """Test that transient files can be registered"""
+        transient_file = "/tmp/test_model.json"
+
+        test_model.register_transient_file(transient_file)
+
+        assert transient_file in test_model._transient_model_files
+
+    def test_register_multiple_transient_files(self, test_model):
+        """Test that multiple transient files can be registered"""
+        transient_files = [
+            "/tmp/model1.json",
+            "/tmp/model2.bst",
+            "/tmp/model3.pkl"
+        ]
+
+        for transient_file in transient_files:
+            test_model.register_transient_file(transient_file)
+
+        assert len(test_model._transient_model_files) == 3
+        for transient_file in transient_files:
+            assert transient_file in test_model._transient_model_files
+
+    def test_register_duplicate_transient_file_ignored(self, test_model):
+        """Test that duplicate transient file registrations are ignored"""
+        transient_file = "/tmp/model.json"
+
+        test_model.register_transient_file(transient_file)
+        test_model.register_transient_file(transient_file)
+        test_model.register_transient_file(transient_file)
+
+        # Should only be registered once
+        assert test_model._transient_model_files.count(transient_file) == 1
+
+    def test_register_empty_path_ignored(self, test_model):
+        """Test that empty paths are ignored"""
+        test_model.register_transient_file("")
+        test_model.register_transient_file(None)
+
+        assert len(test_model._transient_model_files) == 0
+
+    async def test_unload_cleans_up_transient_files(self, test_model, tmp_path):
+        """Test that unload removes registered transient files"""
+        # Create actual transient files
+        transient_file1 = tmp_path / "model1.json"
+        transient_file2 = tmp_path / "model2.bst"
+
+        transient_file1.write_text('{"test": 1}')
+        transient_file2.write_text("xgboost model")
+
+        # Register them
+        test_model.register_transient_file(str(transient_file1))
+        test_model.register_transient_file(str(transient_file2))
+
+        # Verify they exist
+        assert os.path.exists(transient_file1)
+        assert os.path.exists(transient_file2)
+
+        # Unload should clean them up
+        await test_model.unload()
+
+        # Verify they're deleted
+        assert not os.path.exists(transient_file1)
+        assert not os.path.exists(transient_file2)
+
+        # Verify the list is cleared
+        assert len(test_model._transient_model_files) == 0
+
+    async def test_unload_with_nonexistent_file(self, test_model):
+        """Test that unload handles non-existent files gracefully"""
+        test_model.register_transient_file("/tmp/nonexistent_file.json")
+
+        # Should not raise an error
+        result = await test_model.unload()
+
+        assert result is True
+        assert len(test_model._transient_model_files) == 0
+
+    async def test_unload_with_permission_error(self, test_model, tmp_path):
+        """Test that unload handles permission errors gracefully"""
+        transient_file = tmp_path / "readonly_model.json"
+        transient_file.write_text('{"test": "data"}')
+
+        test_model.register_transient_file(str(transient_file))
+
+        # Mock os.remove to raise PermissionError
+        with patch('os.remove', side_effect=PermissionError("Access denied")):
+            # Should not raise an error, just log a warning
+            result = await test_model.unload()
+
+        assert result is True
+        # The list should still be cleared even if deletion failed
+        assert len(test_model._transient_model_files) == 0
+
+    async def test_unload_partial_failure(self, test_model, tmp_path):
+        """Test that unload continues even if some files fail to delete"""
+        # Create two transient files
+        transient_file1 = tmp_path / "model1.json"
+        transient_file2 = tmp_path / "model2.json"
+
+        transient_file1.write_text('{"test": 1}')
+        transient_file2.write_text('{"test": 2}')
+
+        test_model.register_transient_file(str(transient_file1))
+        test_model.register_transient_file(str(transient_file2))
+
+        # Mock os.remove to fail only for the first file
+        original_remove = os.remove
+
+        def mock_remove(path):
+            if str(transient_file1) in path:
+                raise OSError("Failed to delete")
+            original_remove(path)
+
+        with patch('os.remove', side_effect=mock_remove):
+            result = await test_model.unload()
+
+        assert result is True
+        # The second file should still be deleted
+        assert not os.path.exists(transient_file2)
+
+    async def test_model_lifecycle_with_transient_files(self, test_model, tmp_path):
+        """Test the complete model lifecycle with transient file management"""
+        # Load
+        await test_model.load()
+
+        # Simulate transient file creation during load
+        transient_file = tmp_path / "model.json"
+        transient_file.write_text('{"model": "data"}')
+        test_model.register_transient_file(str(transient_file))
+
+        # Model ready
+        test_model.ready = True
+        assert test_model.ready
+        assert os.path.exists(transient_file)
+
+        # Unload
+        await test_model.unload()
+
+        # The transient file should be gone
+        assert not os.path.exists(transient_file)
+
+    async def test_multiple_load_unload_cycles(self, test_model, tmp_path):
+        """Test multiple load/unload cycles"""
+        for i in range(3):
+            # Create a transient file
+            transient_file = tmp_path / f"model_{i}.json"
+            transient_file.write_text(f'{{"cycle": {i}}}')
+
+            # Register
+            test_model.register_transient_file(str(transient_file))
+            assert os.path.exists(transient_file)
+
+            # Unload
+            await test_model.unload()
+            assert not os.path.exists(transient_file)
+            assert len(test_model._transient_model_files) == 0
+
+
+class TestRuntimeIntegration:
+    """Test cleanup integration with actual runtimes"""
+
+    async def test_no_cleanup_when_no_transient_files(self, test_model):
+        """Test that models without transient files work normally"""
+        await test_model.load()
+
+        # Don't register any transient files
+        result = await test_model.unload()
+
+        assert result is True
+
+    async def test_multiple_models_independent_cleanup(self, model_settings, tmp_path):
+        """Test that multiple model instances have independent transient file lists"""
+
+        # Create two separate model instances
+        model_a = SimpleTestModel(
+            ModelSettings(name="model-a", implementation=SimpleTestModel)
+        )
+        model_b = SimpleTestModel(
+            ModelSettings(name="model-b", implementation=SimpleTestModel)
+        )
+
+        # Create transient files for each model
+        transient_file_a = tmp_path / "model_a.json"
+        transient_file_b = tmp_path / "model_b.json"
+
+        transient_file_a.write_text('{"model": "a"}')
+        transient_file_b.write_text('{"model": "b"}')
+
+        # Register transient files to their respective models
+        model_a.register_transient_file(str(transient_file_a))
+        model_b.register_transient_file(str(transient_file_b))
+
+        # Verify both files exist
+        assert os.path.exists(transient_file_a)
+        assert os.path.exists(transient_file_b)
+
+        # Verify each model only tracks its own transient file
+        assert len(model_a._transient_model_files) == 1
+        assert len(model_b._transient_model_files) == 1
+        assert str(transient_file_a) in model_a._transient_model_files
+        assert str(transient_file_b) in model_b._transient_model_files
+
+        # Unload model_a
+        await model_a.unload()
+
+        # Only model_a's transient file should be deleted
+        assert not os.path.exists(transient_file_a)
+        assert os.path.exists(transient_file_b)  # model_b's file still exists
+
+        # model_b's transient file list is unaffected
+        assert len(model_b._transient_model_files) == 1
+
+        # Unload model_b
+        await model_b.unload()
+
+        # Now model_b's transient file should be deleted
+        assert not os.path.exists(transient_file_b)
+
+    async def test_custom_unload_can_call_super(self, model_settings, tmp_path):
+        """Test that a custom unload can call super() for cleanup"""
+
+        class CustomModel(SimpleTestModel):
+            def __init__(self, settings):
+                super().__init__(settings)
+                self.custom_cleanup_called = False
+
+            async def unload(self) -> bool:
+                # Custom cleanup
+                self.custom_cleanup_called = True
+
+                # Call the parent for transient file cleanup
+                return await super().unload()
+
+        model = CustomModel(model_settings)
+
+        # Create a transient file
+        transient_file = tmp_path / "model.json"
+        transient_file.write_text('{"test": "data"}')
+
+        model.register_transient_file(str(transient_file))
+
+        # Unload
+        await model.unload()
+
+        # Both the custom and the base cleanup should have run
+        assert model.custom_cleanup_called
+        assert not os.path.exists(transient_file)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 727002f48..77040a6cd 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,8 +1,11 @@
 import asyncio
 import os
 import signal
+import platform
+import tempfile
+
 from typing import Dict, Optional
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 import pytest
 
@@ -11,6 +14,7 @@
 from mlserver.types import InferenceRequest, InferenceResponse, Parameters
 from mlserver.utils import (
     get_model_uri,
+    ensure_local_path,
     extract_headers,
     insert_headers,
     AsyncManager,
@@ -131,3 +135,310 @@ async def in_event_loop():
 
     async_mgr = AsyncManager(loop_signal_handler_config)
     async_mgr.run(in_event_loop())
+
+
+class TestEnsureLocalPath:
+    """Tests for the ensure_local_path() function"""
+
+    def test_standard_file_no_copy(self, tmp_path):
+        """Test that standard files are returned as-is without copying"""
+        # Create a test file
+        test_file = tmp_path / "model.json"
"model.json" + test_file.write_text('{"test": "data"}') + + result = ensure_local_path(str(test_file)) + + # Should return original path + assert result == str(test_file) + # File should still exist at original location + assert test_file.exists() + + def test_nonexistent_file_returns_original(self): + """Test that non-existent files return original path""" + nonexistent = "/path/that/does/not/exist.json" + + result = ensure_local_path(nonexistent) + + assert result == nonexistent + + def test_symlink_detected_and_copied(self, tmp_path): + """Test that symlinks are detected and copied""" + # Create actual file + actual_file = tmp_path / "actual" / "model.json" + actual_file.parent.mkdir() + actual_file.write_text('{"test": "data"}') + + # Create symlink + link_file = tmp_path / "link" / "model.json" + link_file.parent.mkdir() + os.symlink(actual_file, link_file) + + result = ensure_local_path(str(link_file)) + + # Should return different path (temp location) + assert result != str(link_file) + assert result.startswith("/tmp/mlserver_") + assert result.endswith(".json") + + # Original symlink should still exist + assert link_file.exists() + + # Temp file should exist and have same content + assert os.path.exists(result) + with open(result) as f: + assert f.read() == '{"test": "data"}' + + # Cleanup + os.unlink(result) + + def test_bind_mount_simulation(self, tmp_path): + """Test detection of bind mount (path resolution difference)""" + # We can't create real bind mounts in tests, so we mock os.path.realpath + test_file = tmp_path / "model.json" + test_file.write_text('{"test": "data"}') + + with patch('os.path.realpath') as mock_realpath: + # Simulate bind mount: input is /mnt/models/model.json + # but realpath returns /models/model.json + mock_realpath.return_value = "/models/model.json" + + result = ensure_local_path(str(test_file)) + + # Should detect the difference and copy + assert result != str(test_file) + assert result.startswith("/tmp/mlserver_") + assert result.endswith(".json") + + # Cleanup + if os.path.exists(result): + os.unlink(result) + + def test_proc_path_detected(self, tmp_path): + """Test that /proc/ paths are detected and copied""" + test_file = tmp_path / "model.bst" + test_file.write_text("xgboost model data") + + with patch('os.path.realpath') as mock_realpath: + # Simulate proc-based mount + mock_realpath.return_value = "/proc/123/root/models/model.bst" + + result = ensure_local_path(str(test_file)) + + # Should detect /proc/ and copy + assert result != str(test_file) + assert result.startswith("/tmp/mlserver_") + assert result.endswith(".bst") + + # Cleanup + if os.path.exists(result): + os.unlink(result) + + def test_file_extension_preserved(self, tmp_path): + """Test that file extensions are preserved in temp files""" + extensions = [".json", ".bst", ".pkl", ".joblib", ".bin"] + + for ext in extensions: + test_file = tmp_path / f"model{ext}" + test_file.write_text("test data") + + with patch('os.path.realpath') as mock_realpath: + # Force copy by simulating bind mount + mock_realpath.return_value = "/different/path" + ext + + result = ensure_local_path(str(test_file)) + + # Extension should be preserved + assert result.endswith(ext) + + # Cleanup + if os.path.exists(result): + os.unlink(result) + + def test_file_content_preserved(self, tmp_path): + """Test that file content is correctly copied""" + test_file = tmp_path / "model.json" + original_content = '{"key": "value", "number": 42}' + test_file.write_text(original_content) + + with patch('os.path.realpath') 
+            # Force copy
+            mock_realpath.return_value = "/different/path/model.json"
+
+            result = ensure_local_path(str(test_file))
+
+        # Content should match
+        with open(result) as f:
+            assert f.read() == original_content
+
+        # Cleanup
+        os.unlink(result)
+
+    def test_absolute_vs_relative_paths(self, tmp_path):
+        """Test handling of absolute paths"""
+        test_file = tmp_path / "model.json"
+        test_file.write_text('{"test": "data"}')
+
+        # Absolute path - should work normally
+        result = ensure_local_path(str(test_file.absolute()))
+        assert result == str(test_file.absolute())
+
+    def test_multimodel_scenario(self, tmp_path):
+        """Test multiple models with different paths"""
+        # Create multiple model files
+        model1 = tmp_path / "model1" / "model.json"
+        model1.parent.mkdir()
+        model1.write_text('{"model": 1}')
+
+        model2 = tmp_path / "model2" / "model.bst"
+        model2.parent.mkdir()
+        model2.write_text("xgboost model 2")
+
+        with patch('os.path.realpath') as mock_realpath:
+            # Simulate bind mounts for both
+            def realpath_side_effect(path):
+                if "model1" in path:
+                    return "/models/model1/model.json"
+                elif "model2" in path:
+                    return "/models/model2/model.bst"
+                return path
+
+            mock_realpath.side_effect = realpath_side_effect
+
+            result1 = ensure_local_path(str(model1))
+            result2 = ensure_local_path(str(model2))
+
+        # Both should be copied to different temp files
+        assert result1 != result2
+        assert result1.startswith("/tmp/mlserver_")
+        assert result2.startswith("/tmp/mlserver_")
+        assert result1.endswith(".json")
+        assert result2.endswith(".bst")
+
+        # Both should have correct content
+        with open(result1) as f:
+            assert f.read() == '{"model": 1}'
+        with open(result2) as f:
+            assert f.read() == "xgboost model 2"
+
+        # Cleanup
+        os.unlink(result1)
+        os.unlink(result2)
+
+    def test_broken_symlink_detected(self, tmp_path):
+        """Test that broken symlinks are detected"""
+        # Create a symlink to a non-existent target
+        link_file = tmp_path / "broken_link.json"
+        os.symlink("/nonexistent/target.json", link_file)
+
+        # A broken symlink fails the os.path.exists() check at the top of
+        # ensure_local_path, so the original path is returned unchanged
+        result = ensure_local_path(str(link_file))
+
+        assert result == str(link_file)
+
+
+class TestTemporaryFileCleanup:
+    """Tests for temporary file cleanup functionality"""
+
+    def test_temp_file_is_created_for_bind_mount(self, tmp_path):
+        """Test that a temporary file is created when copying from a bind mount"""
+        test_file = tmp_path / "model.json"
+        test_file.write_text('{"test": "data"}')
+
+        with patch('os.path.realpath') as mock_realpath:
+            # Simulate bind mount
+            mock_realpath.return_value = "/different/path/model.json"
+
+            result = ensure_local_path(str(test_file))
+
+        # Temp file should be created
+        assert result != str(test_file)
+        assert os.path.exists(result)
+        assert result.startswith("/tmp/mlserver_")
+
+        # Cleanup
+        os.unlink(result)
+
+    def test_temp_file_can_be_deleted(self, tmp_path):
+        """Test that temporary files can be deleted successfully"""
+        test_file = tmp_path / "model.bst"
+        test_file.write_text("xgboost data")
+
+        with patch('os.path.realpath') as mock_realpath:
+            mock_realpath.return_value = "/proc/123/root/model.bst"
+
+            temp_path = ensure_local_path(str(test_file))
+
+        # Verify temp file exists
+        assert os.path.exists(temp_path)
+
+        # Delete it (simulating unload)
+        os.remove(temp_path)
+
+        # Verify it's deleted
+        assert not os.path.exists(temp_path)
+
+    def test_multiple_temp_files_can_coexist(self, tmp_path):
+        """Test that multiple models can have temp files simultaneously"""
+        temp_files = []
+
+        for i in range(3):
+            test_file = tmp_path / f"model{i}.json"
+            test_file.write_text(f'{{"model": {i}}}')
+
+            with patch('os.path.realpath') as mock_realpath:
+                mock_realpath.return_value = f"/different/path/model{i}.json"
+
+                temp_path = ensure_local_path(str(test_file))
+                temp_files.append(temp_path)
+
+            # Verify temp file exists
+            assert os.path.exists(temp_path)
+
+        # All temp files should exist
+        assert len(set(temp_files)) == 3  # All unique
+        for temp_file in temp_files:
+            assert os.path.exists(temp_file)
+
+        # Cleanup all
+        for temp_file in temp_files:
+            os.remove(temp_file)
+            assert not os.path.exists(temp_file)
+
+    def test_deletion_of_nonexistent_file_handled_gracefully(self):
+        """Test that deleting a non-existent file doesn't raise an error"""
+        non_existent = "/tmp/mlserver_nonexistent_12345.json"
+
+        # Should not raise an error
+        try:
+            if os.path.exists(non_existent):
+                os.remove(non_existent)
+        except Exception as e:
+            pytest.fail(f"Deletion check raised an unexpected exception: {e}")
+
+    def test_temp_file_survives_until_explicitly_deleted(self, tmp_path):
+        """Test that temp files persist until explicitly deleted"""
+        test_file = tmp_path / "model.pkl"
+        test_file.write_text("sklearn model")
+
+        with patch('os.path.realpath') as mock_realpath:
+            mock_realpath.return_value = "/bind/mount/model.pkl"
+
+            temp_path = ensure_local_path(str(test_file))
+
+        # Verify temp file exists
+        assert os.path.exists(temp_path)
+
+        # Simulate some operations...
+        with open(temp_path) as f:
+            content = f.read()
+            assert content == "sklearn model"
+
+        # File should still exist
+        assert os.path.exists(temp_path)
+
+        # Only deleted when explicitly removed
+        os.remove(temp_path)
+        assert not os.path.exists(temp_path)