diff --git a/app/services/embedding.py b/app/services/embedding.py
index b674270..3606f0c 100644
--- a/app/services/embedding.py
+++ b/app/services/embedding.py
@@ -1,4 +1,7 @@
 """Embedding service — wraps LiteLLM for provider-agnostic vector generation."""
 
 from __future__ import annotations
+
+import asyncio
+import logging
 
@@ -6,6 +9,10 @@
 
 # Default embedding model — small, fast, cheap
 DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
 DEFAULT_EMBEDDING_DIMENSIONS = 1536
 MAX_BATCH_SIZE = 128
+
+# Retry configuration for transient embedding API failures
+MAX_RETRIES = 3
+RETRY_DELAY = 1.0  # base delay in seconds; doubled after each failed attempt
 
@@ -36,9 +43,21 @@ async def embed_texts(
         batch = texts[i : i + MAX_BATCH_SIZE]
         kwargs: dict = {"model": model, "input": batch}
         if api_key:
             kwargs["api_key"] = api_key
 
-        response = await aembedding(**kwargs)
+        # Retry the API call only; parsing the response stays outside the
+        # try block so each batch is appended exactly once.
+        for attempt in range(MAX_RETRIES):
+            try:
+                response = await aembedding(**kwargs)
+                break  # Success, exit retry loop
+            except Exception as e:
+                if attempt == MAX_RETRIES - 1:
+                    # Final attempt failed, surface the error to the caller
+                    logging.error(f"Embedding API failed after {MAX_RETRIES} attempts: {e}")
+                    raise RuntimeError(f"Failed to generate embeddings after {MAX_RETRIES} attempts: {e}") from e
+                # Transient failure, wait and retry
+                logging.warning(f"Embedding API attempt {attempt + 1} failed: {e}. Retrying...")
+                await asyncio.sleep(RETRY_DELAY * (2 ** attempt))  # Exponential backoff
         batch_embeddings = [item["embedding"] for item in response.data]
         all_embeddings.extend(batch_embeddings)
 