|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import logging |
| 4 | +import time |
3 | 5 | from pathlib import Path |
4 | 6 | from typing import Any |
5 | 7 |
|
6 | 8 | from bitloops_embeddings.errors import BackendLoadError, InferenceError |
7 | | -from bitloops_embeddings.logging_utils import log_event |
| 9 | +from bitloops_embeddings.logging_utils import LOGGER_NAME, log_event |
| 10 | + |
| 11 | + |
| 12 | +MODEL_LOAD_RETRY_DELAYS_SECONDS = (5, 10, 20) |
8 | 13 |
|
9 | 14 |
|
10 | 15 | class SentenceTransformersBackend: |
@@ -56,19 +61,36 @@ def load(self) -> None: |
56 | 61 | upstream_model_id=self._upstream_model_id, |
57 | 62 | cache_dir=self._cache_dir, |
58 | 63 | ) |
59 | | - try: |
60 | | - self._model = SentenceTransformer( |
61 | | - self._upstream_model_id, |
62 | | - cache_folder=str(self._cache_dir), |
63 | | - device="cpu", |
64 | | - ) |
65 | | - detected_dimensions = self._model.get_sentence_embedding_dimension() |
66 | | - if detected_dimensions is not None: |
67 | | - self._dimensions = int(detected_dimensions) |
68 | | - except Exception as exc: |
69 | | - raise BackendLoadError( |
70 | | - f"Failed to load model '{self.model_id}' from '{self._upstream_model_id}'." |
71 | | - ) from exc |
| 64 | + max_attempts = len(MODEL_LOAD_RETRY_DELAYS_SECONDS) + 1 |
| 65 | + for attempt in range(1, max_attempts + 1): |
| 66 | + try: |
| 67 | + self._model = SentenceTransformer( |
| 68 | + self._upstream_model_id, |
| 69 | + cache_folder=str(self._cache_dir), |
| 70 | + device="cpu", |
| 71 | + ) |
| 72 | + detected_dimensions = self._model.get_sentence_embedding_dimension() |
| 73 | + if detected_dimensions is not None: |
| 74 | + self._dimensions = int(detected_dimensions) |
| 75 | + break |
| 76 | + except Exception as exc: |
| 77 | + self._model = None |
| 78 | + if attempt >= max_attempts or not _is_retryable_load_exception(exc): |
| 79 | + raise BackendLoadError( |
| 80 | + f"Failed to load model '{self.model_id}' from '{self._upstream_model_id}'." |
| 81 | + ) from exc |
| 82 | + |
| 83 | + delay_seconds = MODEL_LOAD_RETRY_DELAYS_SECONDS[attempt - 1] |
| 84 | + logging.getLogger(LOGGER_NAME).warning( |
| 85 | + "event=model_load_retry model_id=%s backend=%s attempt=%s max_attempts=%s delay_seconds=%s reason=%s", |
| 86 | + self.model_id, |
| 87 | + self.backend_name, |
| 88 | + attempt, |
| 89 | + max_attempts, |
| 90 | + delay_seconds, |
| 91 | + str(exc), |
| 92 | + ) |
| 93 | + time.sleep(delay_seconds) |
72 | 94 |
|
73 | 95 | log_event( |
74 | 96 | "model_load_complete", |
@@ -97,3 +119,23 @@ def embed(self, texts: list[str]) -> list[list[float]]: |
97 | 119 |
|
    def close(self) -> None:
        """Release the backend's reference to the loaded model.

        Drops ``self._model`` so the (potentially large) model object becomes
        eligible for garbage collection. Safe to call when no model is loaded;
        a subsequent ``load()`` can re-populate the backend.
        """
        self._model = None
| 122 | + |
| 123 | + |
| 124 | +def _is_retryable_load_exception(exc: Exception) -> bool: |
| 125 | + message = str(exc).lower() |
| 126 | + retryable_markers = ( |
| 127 | + "http error 500", |
| 128 | + "http error 502", |
| 129 | + "http error 503", |
| 130 | + "http error 504", |
| 131 | + "connection error", |
| 132 | + "connection aborted", |
| 133 | + "connection reset", |
| 134 | + "read timed out", |
| 135 | + "timed out", |
| 136 | + "temporarily unavailable", |
| 137 | + "temporary failure", |
| 138 | + "service unavailable", |
| 139 | + "too many requests", |
| 140 | + ) |
| 141 | + return any(marker in message for marker in retryable_markers) |
0 commit comments