53 changes: 23 additions & 30 deletions lagent/llms/openai.py
@@ -4,7 +4,6 @@
 import random
 import time
 import traceback
-import uuid
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
@@ -16,15 +15,15 @@
 import requests
 from openai import NOT_GIVEN, APITimeoutError, AsyncOpenAI
 from openai.types.chat import ChatCompletion
-from openai.types.chat.chat_completion import Choice
-from openai.types.chat.chat_completion_message import ChatCompletionMessage
 
 from lagent.schema import ModelStatusCode
 from lagent.utils import filter_suffix
 from lagent.llms.base_api import AsyncBaseAPILLM, BaseAPILLM
 
 warnings.simplefilter('default')
 
+logger = getLogger(__name__)
+
 OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'

@@ -846,15 +845,15 @@ def __init__(
         self.extra_body = extra_body
 
     async def chat(self, messages: list[dict], session_id: str | int = None, **kwargs) -> ChatCompletion:
-        fallback_response = ChatCompletion(
-            id=f'chatcmpl-{uuid.uuid4()}',
-            object='chat.completion',
-            created=int(time.time()),
-            model=self.model,
-            choices=[
-                Choice(message=ChatCompletionMessage(role='assistant', content=''), finish_reason='stop', index=0)
-            ],
-        )
+        _RETRYABLE_STRINGS = (
+            "用户额度不足",
+            "剩余额度",
+            "TimeoutError",
+            "litellm.BadRequestError",
+            "litellm.APIError: APIError",
+            "Call `/v1/models` to view available models",
+        )
+        last_exc: Exception | None = None
         for attempt in range(self.max_retry):
             try:
                 client = random.choice(self.clients)
@@ -872,25 +871,19 @@ async def chat(self, messages: list[dict], session_id: str | int = None, **kwargs
                 )
                 return response
             except (APITimeoutError, TimeoutError) as e:
-                print(f"LLM Call Timeout: {e}")
-                if attempt == self.max_retry - 1:
-                    return fallback_response
-                await asyncio.sleep(self.sleep_interval)
+                last_exc = e
+                logger.warning("LLM call timeout (attempt %d/%d): %s", attempt + 1, self.max_retry, e)
+                if attempt < self.max_retry - 1:
+                    await asyncio.sleep(self.sleep_interval)
             except Exception as e:
-                for val in [
-                    "用户额度不足",
-                    "剩余额度",
-                    "TimeoutError",
-                    "litellm.BadRequestError",
-                    "litellm.APIError: APIError",
-                    "Call `/v1/models` to view available models",
-                ]:
-                    if val in str(e):
-                        print(f"LLM Call Error: {e}")
-                        if attempt == self.max_retry - 1:
-                            return fallback_response
+                last_exc = e
+                err_str = str(e)
+                if any(val in err_str for val in _RETRYABLE_STRINGS):
+                    logger.warning("LLM call error (attempt %d/%d): %s", attempt + 1, self.max_retry, e)
+                    if attempt < self.max_retry - 1:
                         await asyncio.sleep(self.sleep_interval)
-                        break
                 else:
-                    return fallback_response
-        return fallback_response
+                    raise
+        raise RuntimeError(
+            f"LLM call failed after {self.max_retry} retries"
+        ) from last_exc
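
For reviewers who want to trace the new control flow outside the class, below is a minimal, self-contained sketch of the retry pattern this diff introduces: exceptions whose message contains a known retryable substring are logged with the attempt count and retried after a sleep, anything else is re-raised immediately, and exhausting the retry budget raises a `RuntimeError` chained to the last exception instead of silently returning an empty fallback completion. The `call_model` stub, the shortened substring tuple, the retry count, and the sleep interval are illustrative stand-ins, not part of the PR.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

# Substrings copied from the diff (shortened here); an error whose
# message contains one of them is considered retryable.
RETRYABLE_STRINGS = ("用户额度不足", "剩余额度", "TimeoutError")


async def call_model() -> str:
    """Illustrative stand-in for client.chat.completions.create(...)."""
    raise RuntimeError("用户额度不足")  # simulate a retryable quota error


async def chat_with_retries(max_retry: int = 3, sleep_interval: float = 1.0) -> str:
    last_exc: Exception | None = None
    for attempt in range(max_retry):
        try:
            return await call_model()
        except Exception as e:
            last_exc = e
            if any(val in str(e) for val in RETRYABLE_STRINGS):
                logger.warning("LLM call error (attempt %d/%d): %s", attempt + 1, max_retry, e)
                if attempt < max_retry - 1:
                    await asyncio.sleep(sleep_interval)
            else:
                raise  # non-retryable errors propagate immediately
    # Retry budget exhausted: fail loudly, keeping the last error as the cause.
    raise RuntimeError(f"LLM call failed after {max_retry} retries") from last_exc


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    try:
        asyncio.run(chat_with_retries(sleep_interval=0.1))
    except RuntimeError as exc:
        print(f"{exc} (cause: {exc.__cause__!r})")
```

Compared with the old `for val in [...]`/`else` scan, `any(...)` over a tuple expresses the same membership test without the easy-to-misread loop `else`, and `raise ... from last_exc` preserves the original traceback that the discarded `fallback_response` (an empty assistant message with `finish_reason='stop'`) used to swallow.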