diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 83c13b3..b183483 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -4,7 +4,6 @@
 import random
 import time
 import traceback
-import uuid
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
@@ -16,8 +15,6 @@
 import requests
 from openai import NOT_GIVEN, APITimeoutError, AsyncOpenAI
 from openai.types.chat import ChatCompletion
-from openai.types.chat.chat_completion import Choice
-from openai.types.chat.chat_completion_message import ChatCompletionMessage
 
 from lagent.schema import ModelStatusCode
 from lagent.utils import filter_suffix
@@ -25,6 +22,8 @@
 
 warnings.simplefilter('default')
 
+logger = getLogger(__name__)
+
 OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
 
 
@@ -846,15 +845,15 @@ def __init__(
         self.extra_body = extra_body
 
     async def chat(self, messages: list[dict], session_id: str | int = None, **kwargs) -> ChatCompletion:
-        fallback_response = ChatCompletion(
-            id=f'chatcmpl-{uuid.uuid4()}',
-            object='chat.completion',
-            created=int(time.time()),
-            model=self.model,
-            choices=[
-                Choice(message=ChatCompletionMessage(role='assistant', content=''), finish_reason='stop', index=0)
-            ],
+        _RETRYABLE_STRINGS = (
+            "用户额度不足",
+            "剩余额度",
+            "TimeoutError",
+            "litellm.BadRequestError",
+            "litellm.APIError: APIError",
+            "Call `/v1/models` to view available models",
         )
+        last_exc: Exception | None = None
         for attempt in range(self.max_retry):
             try:
                 client = random.choice(self.clients)
@@ -872,25 +871,19 @@ async def chat(self, messages: list[dict], session_id: str | int = None, **kwarg
                 )
                 return response
             except (APITimeoutError, TimeoutError) as e:
-                print(f"LLM Call Timeout: {e}")
-                if attempt == self.max_retry - 1:
-                    return fallback_response
-                await asyncio.sleep(self.sleep_interval)
+                last_exc = e
+                logger.warning("LLM call timeout (attempt %d/%d): %s", attempt + 1, self.max_retry, e)
+                if attempt < self.max_retry - 1:
+                    await asyncio.sleep(self.sleep_interval)
             except Exception as e:
-                for val in [
-                    "用户额度不足",
-                    "剩余额度",
-                    "TimeoutError",
-                    "litellm.BadRequestError",
-                    "litellm.APIError: APIError",
-                    "Call `/v1/models` to view available models",
-                ]:
-                    if val in str(e):
-                        print(f"LLM Call Error: {e}")
-                        if attempt == self.max_retry - 1:
-                            return fallback_response
+                last_exc = e
+                err_str = str(e)
+                if any(val in err_str for val in _RETRYABLE_STRINGS):
+                    logger.warning("LLM call error (attempt %d/%d): %s", attempt + 1, self.max_retry, e)
+                    if attempt < self.max_retry - 1:
                         await asyncio.sleep(self.sleep_interval)
-                        break
                 else:
-                    return fallback_response
-        return fallback_response
+                    raise
+        raise RuntimeError(
+            f"LLM call failed after {self.max_retry} retries"
+        ) from last_exc