Skip to content

Commit 7615541

Browse files
committed
Retry transient model download failures
1 parent e7caf6b commit 7615541

3 files changed

Lines changed: 122 additions & 18 deletions

File tree

scripts/real_backend_smoke.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from urllib import error, request
1212

1313

14+
SMOKE_RETRY_DELAYS_SECONDS = (10, 20, 40)
15+
16+
1417
def main() -> None:
1518
parser = argparse.ArgumentParser(description="Run a real backend smoke test.")
1619
parser.add_argument(
@@ -21,11 +24,28 @@ def main() -> None:
2124
args = parser.parse_args()
2225

2326
binary = args.binary
24-
port = reserve_free_port()
27+
run_with_retries("embed smoke", lambda: run_embed_smoke(binary))
28+
run_with_retries("server smoke", lambda: run_server_smoke(binary, reserve_free_port()))
29+
run_with_retries("daemon smoke", lambda: run_daemon_smoke(binary))
30+
2531

26-
run_embed_smoke(binary)
27-
run_server_smoke(binary, port)
28-
run_daemon_smoke(binary)
32+
def run_with_retries(name: str, operation) -> None:
    """Invoke *operation*, retrying on RuntimeError with escalating delays.

    One retry is granted per entry in SMOKE_RETRY_DELAYS_SECONDS; if the
    final attempt also fails, its RuntimeError propagates to the caller.
    Progress is reported on stderr so CI logs show each retry.
    """
    attempts_allowed = len(SMOKE_RETRY_DELAYS_SECONDS) + 1
    for attempt, delay_seconds in enumerate(SMOKE_RETRY_DELAYS_SECONDS, start=1):
        try:
            operation()
            return
        except RuntimeError as exc:
            print(
                f"{name} failed on attempt {attempt}/{attempts_allowed}: {exc}\nRetrying in {delay_seconds}s...",
                file=sys.stderr,
                flush=True,
            )
            time.sleep(delay_seconds)
    # Last attempt: no retry budget left, so any RuntimeError is re-raised.
    operation()
2949

3050

3151
def run_embed_smoke(binary: str) -> None:

src/bitloops_embeddings/backend/sentence_transformers_backend.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import logging
4+
import time
35
from pathlib import Path
46
from typing import Any
57

68
from bitloops_embeddings.errors import BackendLoadError, InferenceError
7-
from bitloops_embeddings.logging_utils import log_event
9+
from bitloops_embeddings.logging_utils import LOGGER_NAME, log_event
10+
11+
12+
MODEL_LOAD_RETRY_DELAYS_SECONDS = (5, 10, 20)
813

914

1015
class SentenceTransformersBackend:
@@ -56,19 +61,36 @@ def load(self) -> None:
5661
upstream_model_id=self._upstream_model_id,
5762
cache_dir=self._cache_dir,
5863
)
59-
try:
60-
self._model = SentenceTransformer(
61-
self._upstream_model_id,
62-
cache_folder=str(self._cache_dir),
63-
device="cpu",
64-
)
65-
detected_dimensions = self._model.get_sentence_embedding_dimension()
66-
if detected_dimensions is not None:
67-
self._dimensions = int(detected_dimensions)
68-
except Exception as exc:
69-
raise BackendLoadError(
70-
f"Failed to load model '{self.model_id}' from '{self._upstream_model_id}'."
71-
) from exc
64+
max_attempts = len(MODEL_LOAD_RETRY_DELAYS_SECONDS) + 1
65+
for attempt in range(1, max_attempts + 1):
66+
try:
67+
self._model = SentenceTransformer(
68+
self._upstream_model_id,
69+
cache_folder=str(self._cache_dir),
70+
device="cpu",
71+
)
72+
detected_dimensions = self._model.get_sentence_embedding_dimension()
73+
if detected_dimensions is not None:
74+
self._dimensions = int(detected_dimensions)
75+
break
76+
except Exception as exc:
77+
self._model = None
78+
if attempt >= max_attempts or not _is_retryable_load_exception(exc):
79+
raise BackendLoadError(
80+
f"Failed to load model '{self.model_id}' from '{self._upstream_model_id}'."
81+
) from exc
82+
83+
delay_seconds = MODEL_LOAD_RETRY_DELAYS_SECONDS[attempt - 1]
84+
logging.getLogger(LOGGER_NAME).warning(
85+
"event=model_load_retry model_id=%s backend=%s attempt=%s max_attempts=%s delay_seconds=%s reason=%s",
86+
self.model_id,
87+
self.backend_name,
88+
attempt,
89+
max_attempts,
90+
delay_seconds,
91+
str(exc),
92+
)
93+
time.sleep(delay_seconds)
7294

7395
log_event(
7496
"model_load_complete",
@@ -97,3 +119,23 @@ def embed(self, texts: list[str]) -> list[list[float]]:
97119

98120
def close(self) -> None:
99121
self._model = None
122+
123+
124+
def _is_retryable_load_exception(exc: Exception) -> bool:
125+
message = str(exc).lower()
126+
retryable_markers = (
127+
"http error 500",
128+
"http error 502",
129+
"http error 503",
130+
"http error 504",
131+
"connection error",
132+
"connection aborted",
133+
"connection reset",
134+
"read timed out",
135+
"timed out",
136+
"temporarily unavailable",
137+
"temporary failure",
138+
"service unavailable",
139+
"too many requests",
140+
)
141+
return any(marker in message for marker in retryable_markers)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from types import ModuleType
5+
6+
from bitloops_embeddings.backend.sentence_transformers_backend import SentenceTransformersBackend
7+
8+
9+
class FakeSentenceTransformer:
    """Stand-in for ``sentence_transformers.SentenceTransformer``.

    Construction raises a retryable-looking HTTP 503 error on the first
    two instantiations and succeeds on the third, letting tests exercise
    the backend's transient-failure retry loop.
    """

    # Class-level counter shared across all construction attempts.
    attempts = 0

    def __init__(self, *args, **kwargs) -> None:
        cls = type(self)
        cls.attempts += 1
        if cls.attempts < 3:
            raise RuntimeError("HTTP Error 503 thrown while requesting HEAD https://huggingface.co/BAAI/bge-m3/resolve/main/config.json")

    def get_sentence_embedding_dimension(self) -> int:
        # Matches the real BAAI/bge-m3 embedding width.
        return 1024
19+
20+
21+
def test_sentence_transformers_backend_retries_transient_load_failures(
    monkeypatch,
    tmp_path,
) -> None:
    """load() should survive two transient 503s and succeed on the third try."""
    # Install the fake sentence_transformers module before the backend imports it.
    stub_module = ModuleType("sentence_transformers")
    stub_module.SentenceTransformer = FakeSentenceTransformer
    monkeypatch.setitem(sys.modules, "sentence_transformers", stub_module)
    # Neutralize the backoff sleeps so the test stays fast.
    monkeypatch.setattr("bitloops_embeddings.backend.sentence_transformers_backend.time.sleep", lambda _: None)
    FakeSentenceTransformer.attempts = 0

    backend = SentenceTransformersBackend(
        model_id="bge-m3",
        upstream_model_id="BAAI/bge-m3",
        cache_dir=tmp_path / "cache",
        dimensions=1024,
    )

    backend.load()

    assert backend.is_loaded is True
    assert backend.dimensions == 1024
    # Two failed constructions plus the successful one.
    assert FakeSentenceTransformer.attempts == 3

0 commit comments

Comments
 (0)