From 625c2d5a77a65f94ae2f377023ba6b8968f3d351 Mon Sep 17 00:00:00 2001 From: Junyan Chin Date: Sun, 1 Mar 2026 18:11:43 +0800 Subject: [PATCH 1/7] feat: model fallback chain (#2017) --------- Co-authored-by: TyperBody --- .../libs/wechatpad_api/util/http_util.py | 28 +- .../controller/groups/provider/api_chains.py | 41 + src/langbot/pkg/api/http/service/api_chain.py | 149 ++++ src/langbot/pkg/api/http/service/space.py | 82 +- src/langbot/pkg/core/app.py | 6 + src/langbot/pkg/core/bootutils/deps.py | 8 +- src/langbot/pkg/core/stages/build_app.py | 9 + .../pkg/entity/persistence/api_chain.py | 98 +++ .../migrations/dbm008_api_chain.py | 46 ++ .../dbm020_api_chain_model_api_level.py | 39 + .../dbm021_api_chain_health_check_flag.py | 27 + .../cntfilter/filters/baiduexamine.py | 83 +- src/langbot/pkg/pipeline/preproc/preproc.py | 52 +- src/langbot/pkg/platform/sources/discord.py | 60 +- src/langbot/pkg/platform/sources/kook.py | 68 +- src/langbot/pkg/platform/sources/lark.py | 24 +- .../pkg/platform/sources/legacy/gewechat.py | 18 +- src/langbot/pkg/platform/sources/telegram.py | 46 +- src/langbot/pkg/platform/webhook_pusher.py | 34 +- .../pkg/provider/modelmgr/api_chain.py | 757 ++++++++++++++++++ .../pkg/provider/runners/localagent.py | 122 ++- src/langbot/pkg/provider/runners/n8nsvapi.py | 80 +- src/langbot/pkg/utils/image.py | 163 ++-- .../templates/metadata/pipeline/ai.yaml | 6 +- web/eslint.config.mjs | 3 + .../home/bots/components/bot-form/BotForm.tsx | 8 +- .../api-chains-dialog/APIChainCard.tsx | 495 ++++++++++++ .../api-chains-dialog/APIChainForm.tsx | 665 +++++++++++++++ .../api-chains-dialog/APIChainsDialog.tsx | 179 +++++ .../dynamic-form/DynamicFormComponent.tsx | 41 +- .../dynamic-form/DynamicFormItemComponent.tsx | 104 ++- .../components/models-dialog/ModelsDialog.tsx | 37 +- web/src/app/infra/entities/api/api_chain.ts | 64 ++ web/src/app/infra/entities/api/index.ts | 58 ++ web/src/app/infra/entities/form/dynamic.ts | 1 + 
web/src/app/infra/http/BackendClient.ts | 29 + web/src/i18n/locales/en-US.ts | 56 ++ web/src/i18n/locales/ja-JP.ts | 58 +- web/src/i18n/locales/zh-Hans.ts | 56 ++ web/src/i18n/locales/zh-Hant.ts | 55 ++ 40 files changed, 3487 insertions(+), 468 deletions(-) create mode 100644 src/langbot/pkg/api/http/controller/groups/provider/api_chains.py create mode 100644 src/langbot/pkg/api/http/service/api_chain.py create mode 100644 src/langbot/pkg/entity/persistence/api_chain.py create mode 100644 src/langbot/pkg/persistence/migrations/dbm008_api_chain.py create mode 100644 src/langbot/pkg/persistence/migrations/dbm020_api_chain_model_api_level.py create mode 100644 src/langbot/pkg/persistence/migrations/dbm021_api_chain_health_check_flag.py create mode 100644 src/langbot/pkg/provider/modelmgr/api_chain.py create mode 100644 web/src/app/home/components/api-chains-dialog/APIChainCard.tsx create mode 100644 web/src/app/home/components/api-chains-dialog/APIChainForm.tsx create mode 100644 web/src/app/home/components/api-chains-dialog/APIChainsDialog.tsx create mode 100644 web/src/app/infra/entities/api/api_chain.ts diff --git a/src/langbot/libs/wechatpad_api/util/http_util.py b/src/langbot/libs/wechatpad_api/util/http_util.py index 7390f43ec..447c29df0 100644 --- a/src/langbot/libs/wechatpad_api/util/http_util.py +++ b/src/langbot/libs/wechatpad_api/util/http_util.py @@ -1,5 +1,5 @@ import requests -from langbot.pkg.utils import httpclient +import aiohttp def post_json(base_url, token, data=None): @@ -63,16 +63,16 @@ async def async_request( """ headers = {'Content-Type': 'application/json'} url = f'{base_url}?key={token_key}' - session = httpclient.get_session() - async with session.request( - method=method, url=url, params=params, headers=headers, data=data, json=json - ) as response: - response.raise_for_status() # 如果状态码不是200,抛出异常 - result = await response.json() - # print(result) - return result - # if result.get('Code') == 200: - # - # return await result - # else: - # 
raise RuntimeError("请求失败",response.text) + async with aiohttp.ClientSession() as session: + async with session.request( + method=method, url=url, params=params, headers=headers, data=data, json=json + ) as response: + response.raise_for_status() # 如果状态码不是200,抛出异常 + result = await response.json() + # print(result) + return result + # if result.get('Code') == 200: + # + # return await result + # else: + # raise RuntimeError("请求失败",response.text) diff --git a/src/langbot/pkg/api/http/controller/groups/provider/api_chains.py b/src/langbot/pkg/api/http/controller/groups/provider/api_chains.py new file mode 100644 index 000000000..fd921f135 --- /dev/null +++ b/src/langbot/pkg/api/http/controller/groups/provider/api_chains.py @@ -0,0 +1,41 @@ +"""API Chain HTTP Controller""" + +import quart + +from ... import group + + +@group.group_class('api_chains', '/api/v1/provider/api-chains') +class APIChainRouterGroup(group.RouterGroup): + async def initialize(self) -> None: + @self.route('', methods=['GET', 'POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _() -> str: + if quart.request.method == 'GET': + chains = await self.ap.api_chain_service.get_api_chains() + return self.success(data={'chains': chains}) + elif quart.request.method == 'POST': + json_data = await quart.request.json + chain_uuid = await self.ap.api_chain_service.create_api_chain(json_data) + return self.success(data={'uuid': chain_uuid}) + + @self.route('/', methods=['GET', 'PUT', 'DELETE'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _(chain_uuid: str) -> str: + if quart.request.method == 'GET': + chain = await self.ap.api_chain_service.get_api_chain(chain_uuid) + + if chain is None: + return self.http_status(404, -1, 'API chain not found') + + return self.success(data={'chain': chain}) + elif quart.request.method == 'PUT': + json_data = await quart.request.json + await self.ap.api_chain_service.update_api_chain(chain_uuid, json_data) + return self.success() + elif 
quart.request.method == 'DELETE': + await self.ap.api_chain_service.delete_api_chain(chain_uuid) + return self.success() + + @self.route('//test', methods=['POST'], auth_type=group.AuthType.USER_TOKEN_OR_API_KEY) + async def _(chain_uuid: str) -> str: + result = await self.ap.api_chain_service.test_api_chain(chain_uuid) + return self.success(data=result) diff --git a/src/langbot/pkg/api/http/service/api_chain.py b/src/langbot/pkg/api/http/service/api_chain.py new file mode 100644 index 000000000..5d12196cb --- /dev/null +++ b/src/langbot/pkg/api/http/service/api_chain.py @@ -0,0 +1,149 @@ +"""API Chain Service - HTTP service for managing API chains""" + +from __future__ import annotations + +import uuid +from typing import Dict, Any, List +import sqlalchemy + +from ....core import app +from ....entity.persistence import api_chain as api_chain_entity +# NOTE: uuid and sqlalchemy are kept for the read methods; mutations delegate to api_chain_mgr + + +class APIChainService: + """Service for managing API chains""" + + ap: app.Application + + def __init__(self, ap: app.Application) -> None: + self.ap = ap + + async def get_api_chains(self) -> List[Dict[str, Any]]: + """Get all API chains with their statuses""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(api_chain_entity.APIChain)) + + chains = [] + for chain in result.all(): + # Get status for all providers in this chain + status_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChainStatus).where( + api_chain_entity.APIChainStatus.chain_uuid == chain.uuid + ) + ) + + statuses = [] + for status in status_result.all(): + statuses.append( + { + 'provider_uuid': status.provider_uuid, + 'model_name': status.model_name, + 'api_key_index': status.api_key_index, + 'is_healthy': status.is_healthy, + 'failure_count': status.failure_count, + 'last_failure_time': status.last_failure_time.isoformat() if status.last_failure_time else None, + 
'last_success_time': status.last_success_time.isoformat() if status.last_success_time else None, + 'last_health_check_time': status.last_health_check_time.isoformat() + if status.last_health_check_time + else None, + 'last_error_message': status.last_error_message, + 'health_check_last_failed': status.health_check_last_failed, + } + ) + + chains.append( + { + 'uuid': chain.uuid, + 'name': chain.name, + 'description': chain.description, + 'chain_config': chain.chain_config, + 'health_check_interval': chain.health_check_interval, + 'health_check_enabled': chain.health_check_enabled, + 'created_at': chain.created_at.isoformat() if chain.created_at else None, + 'updated_at': chain.updated_at.isoformat() if chain.updated_at else None, + 'statuses': statuses, + } + ) + + return chains + + async def get_api_chain(self, chain_uuid: str) -> Dict[str, Any] | None: + """Get a specific API chain""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChain).where(api_chain_entity.APIChain.uuid == chain_uuid) + ) + + chain = result.first() + if not chain: + return None + + # Get status for all providers in the chain + status_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChainStatus).where( + api_chain_entity.APIChainStatus.chain_uuid == chain_uuid + ) + ) + + statuses = [] + for status in status_result.all(): + statuses.append( + { + 'provider_uuid': status.provider_uuid, + 'model_name': status.model_name, + 'api_key_index': status.api_key_index, + 'is_healthy': status.is_healthy, + 'failure_count': status.failure_count, + 'last_failure_time': status.last_failure_time.isoformat() if status.last_failure_time else None, + 'last_success_time': status.last_success_time.isoformat() if status.last_success_time else None, + 'last_health_check_time': status.last_health_check_time.isoformat() + if status.last_health_check_time + else None, + 'last_error_message': status.last_error_message, + 
'health_check_last_failed': status.health_check_last_failed, + } + ) + + return { + 'uuid': chain.uuid, + 'name': chain.name, + 'description': chain.description, + 'chain_config': chain.chain_config, + 'health_check_interval': chain.health_check_interval, + 'health_check_enabled': chain.health_check_enabled, + 'created_at': chain.created_at.isoformat() if chain.created_at else None, + 'updated_at': chain.updated_at.isoformat() if chain.updated_at else None, + 'statuses': statuses, + } + + async def create_api_chain(self, chain_data: Dict[str, Any]) -> str: + """Create a new API chain""" + chain_data = dict(chain_data) + chain_data.setdefault('uuid', str(uuid.uuid4())) + chain_data.setdefault('chain_config', []) + chain_data.setdefault('health_check_interval', 300) + chain_data.setdefault('health_check_enabled', True) + + # Delegate to manager so in-memory state and health-check task are created + await self.ap.api_chain_mgr.create_chain(chain_data) + return chain_data['uuid'] + + async def update_api_chain(self, chain_uuid: str, chain_data: Dict[str, Any]): + """Update an existing API chain""" + chain_data = dict(chain_data) + chain_data.pop('uuid', None) + chain_data.pop('created_at', None) + chain_data.pop('updated_at', None) + + # Delegate to manager so in-memory state and health-check task are refreshed + await self.ap.api_chain_mgr.update_chain(chain_uuid, chain_data) + + async def delete_api_chain(self, chain_uuid: str): + """Delete an API chain""" + # Delegate to manager which handles both DB deletion and memory/task cleanup + await self.ap.api_chain_mgr.delete_chain(chain_uuid) + + async def test_api_chain(self, chain_uuid: str) -> Dict[str, Any]: + """Test an API chain by making a simple request""" + # This would make a test request through the chain + # For now, just return success + return {'success': True, 'message': 'API chain test not yet implemented'} diff --git a/src/langbot/pkg/api/http/service/space.py b/src/langbot/pkg/api/http/service/space.py 
index c05e4896a..cd6948833 100644 --- a/src/langbot/pkg/api/http/service/space.py +++ b/src/langbot/pkg/api/http/service/space.py @@ -1,6 +1,6 @@ from __future__ import annotations -from langbot.pkg.utils import httpclient +import aiohttp import typing import datetime import time @@ -99,49 +99,49 @@ async def exchange_oauth_code(self, code: str) -> typing.Dict: space_config = self._get_space_config() space_url = space_config['url'] - session = httpclient.get_session() - async with session.post( - f'{space_url}/api/v1/accounts/oauth/token', - json={'code': code, 'instance_id': constants.instance_id}, - ) as response: - if response.status != 200: - raise ValueError(f'Failed to exchange OAuth code: {await response.text()}') - data = await response.json() - if data.get('code') != 0: - raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}') - return data.get('data', {}) + async with aiohttp.ClientSession() as session: + async with session.post( + f'{space_url}/api/v1/accounts/oauth/token', + json={'code': code, 'instance_id': constants.instance_id}, + ) as response: + if response.status != 200: + raise ValueError(f'Failed to exchange OAuth code: {await response.text()}') + data = await response.json() + if data.get('code') != 0: + raise ValueError(f'Failed to exchange OAuth code: {data.get("msg")}') + return data.get('data', {}) async def refresh_token(self, refresh_token: str) -> typing.Dict: """Refresh Space access token""" space_config = self._get_space_config() space_url = space_config['url'] - session = httpclient.get_session() - async with session.post( - f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token} - ) as response: - if response.status != 200: - raise ValueError(f'Failed to refresh token: {await response.text()}') - data = await response.json() - if data.get('code') != 0: - raise ValueError(f'Failed to refresh token: {data.get("msg")}') - return data.get('data', {}) + async with aiohttp.ClientSession() as 
session: + async with session.post( + f'{space_url}/api/v1/accounts/token/refresh', json={'refresh_token': refresh_token} + ) as response: + if response.status != 200: + raise ValueError(f'Failed to refresh token: {await response.text()}') + data = await response.json() + if data.get('code') != 0: + raise ValueError(f'Failed to refresh token: {data.get("msg")}') + return data.get('data', {}) async def get_user_info_raw(self, access_token: str) -> typing.Dict: """Get user info from Space using access token (no validation)""" space_config = self._get_space_config() space_url = space_config['url'] - session = httpclient.get_session() - async with session.get( - f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'} - ) as response: - if response.status != 200: - raise ValueError(f'Failed to get user info: {await response.text()}') - data = await response.json() - if data.get('code') != 0: - raise ValueError(f'Failed to get user info: {data.get("msg")}') - return data.get('data', {}) + async with aiohttp.ClientSession() as session: + async with session.get( + f'{space_url}/api/v1/accounts/me', headers={'Authorization': f'Bearer {access_token}'} + ) as response: + if response.status != 200: + raise ValueError(f'Failed to get user info: {await response.text()}') + data = await response.json() + if data.get('code') != 0: + raise ValueError(f'Failed to get user info: {data.get("msg")}') + return data.get('data', {}) # === API calls with token validation === @@ -178,12 +178,12 @@ async def get_models(self) -> typing.List[SpaceModel]: space_config = self._get_space_config() space_url = space_config['url'] - session = httpclient.get_session() - async with session.get(f'{space_url}/api/v1/models') as response: - if response.status != 200: - raise ValueError(f'Failed to get models: {await response.text()}') - data = await response.json() - if data.get('code') != 0: - raise ValueError(f'Failed to get models: {data.get("msg")}') - models_data = 
data.get('data', {}).get('models', []) - return [SpaceModel.model_validate(model_dict) for model_dict in models_data] + async with aiohttp.ClientSession() as session: + async with session.get(f'{space_url}/api/v1/models') as response: + if response.status != 200: + raise ValueError(f'Failed to get models: {await response.text()}') + data = await response.json() + if data.get('code') != 0: + raise ValueError(f'Failed to get models: {data.get("msg")}') + models_data = data.get('data', {}).get('models', []) + return [SpaceModel.model_validate(model_dict) for model_dict in models_data] diff --git a/src/langbot/pkg/core/app.py b/src/langbot/pkg/core/app.py index 98e886175..a7928341f 100644 --- a/src/langbot/pkg/core/app.py +++ b/src/langbot/pkg/core/app.py @@ -9,6 +9,7 @@ from ..platform.webhook_pusher import WebhookPusher from ..provider.session import sessionmgr as llm_session_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr +from ..provider.modelmgr import api_chain as api_chain_module from langbot.pkg.provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..command import cmdmgr @@ -30,6 +31,7 @@ from ..api.http.service import apikey as apikey_service from ..api.http.service import webhook as webhook_service from ..api.http.service import monitoring as monitoring_service +from ..api.http.service import api_chain as api_chain_service from ..discover import engine as discover_engine from ..storage import mgr as storagemgr from ..utils import logcache @@ -62,6 +64,8 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + api_chain_mgr: api_chain_module.APIChainManager = None + rag_mgr: rag_mgr.RAGManager = None rag_runtime_service: RAGRuntimeService = None @@ -151,6 +155,8 @@ class Application: monitoring_service: monitoring_service.MonitoringService = None + api_chain_service: api_chain_service.APIChainService = None + def __init__(self): pass diff --git a/src/langbot/pkg/core/bootutils/deps.py 
b/src/langbot/pkg/core/bootutils/deps.py index 1f6530379..b2508b22e 100644 --- a/src/langbot/pkg/core/bootutils/deps.py +++ b/src/langbot/pkg/core/bootutils/deps.py @@ -1,4 +1,3 @@ -import importlib.util import pip import os from ...utils import pkgmgr @@ -50,10 +49,9 @@ async def check_deps() -> list[str]: missing_deps = [] for dep in required_deps: - # Use find_spec instead of __import__ to avoid actually loading - # all modules into memory. find_spec only checks if the module - # can be found, without executing module-level code. - if importlib.util.find_spec(dep) is None: + try: + __import__(dep) + except ImportError: missing_deps.append(dep) return missing_deps diff --git a/src/langbot/pkg/core/stages/build_app.py b/src/langbot/pkg/core/stages/build_app.py index 62f0ae7b5..d20a1340b 100644 --- a/src/langbot/pkg/core/stages/build_app.py +++ b/src/langbot/pkg/core/stages/build_app.py @@ -10,6 +10,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr +from ...provider.modelmgr import api_chain as api_chain_module from ...provider.tools import toolmgr as llm_tool_mgr from ...rag.knowledge import kbmgr as rag_mgr from ...rag.service import RAGRuntimeService @@ -28,6 +29,7 @@ from ...api.http.service import apikey as apikey_service from ...api.http.service import webhook as webhook_service from ...api.http.service import monitoring as monitoring_service +from ...api.http.service import api_chain as api_chain_service from ...discover import engine as discover_engine from ...storage import mgr as storagemgr from ...utils import logcache @@ -121,6 +123,10 @@ async def run(self, ap: app.Application): ap.model_mgr = llm_model_mgr_inst await llm_model_mgr_inst.initialize() + api_chain_mgr_inst = api_chain_module.APIChainManager(ap) + ap.api_chain_mgr = api_chain_mgr_inst + await api_chain_mgr_inst.initialize() + llm_session_mgr_inst = llm_session_mgr.SessionManager(ap) 
await llm_session_mgr_inst.initialize() ap.sess_mgr = llm_session_mgr_inst @@ -164,6 +170,9 @@ async def run(self, ap: app.Application): monitoring_service_inst = monitoring_service.MonitoringService(ap) ap.monitoring_service = monitoring_service_inst + api_chain_service_inst = api_chain_service.APIChainService(ap) + ap.api_chain_service = api_chain_service_inst + async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None: await asyncio.sleep(3) await plugin_connector_inst.initialize() diff --git a/src/langbot/pkg/entity/persistence/api_chain.py b/src/langbot/pkg/entity/persistence/api_chain.py new file mode 100644 index 000000000..b40903008 --- /dev/null +++ b/src/langbot/pkg/entity/persistence/api_chain.py @@ -0,0 +1,98 @@ +import sqlalchemy +from sqlalchemy import JSON, Integer, String, DateTime, Boolean +from .base import Base + + +class APIChain(Base): + """API Chain - manages multiple API providers with priority and failover""" + + __tablename__ = 'api_chains' + + uuid = sqlalchemy.Column(String(255), primary_key=True, unique=True) + name = sqlalchemy.Column(String(255), nullable=False) + description = sqlalchemy.Column(String(512), nullable=True) + + # Chain configuration + chain_config = sqlalchemy.Column(JSON, nullable=False, default=list) + """ + List of API chain items: + [ + { + "provider_uuid": "xxx", + "priority": 1, // provider priority in the chain + "is_aggregated": false, + "max_retries": 3, + "timeout_ms": 30000, + "model_configs": [ // optional: per-model priority config + { + "model_name": "gpt-4o", // model name (as in LLMModel.name) + "priority": 1, // model priority within this provider + "api_key_indices": [ // optional: per-API-key priority + {"index": 0, "priority": 1}, + {"index": 1, "priority": 2} + ] + } + ] + }, + ... + ] + If model_configs is empty/absent, the chain uses the query's original model + with round-robin API key rotation. 
If api_key_indices is empty/absent for a + model config, round-robin rotation is used for that model. + """ + + # Health check configuration + health_check_interval = sqlalchemy.Column(Integer, nullable=False, default=300) + """Health check interval in seconds for failed APIs""" + + health_check_enabled = sqlalchemy.Column(Boolean, nullable=False, default=True) + """Whether to enable automatic health check for failed APIs""" + + # Metadata + created_at = sqlalchemy.Column(DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) + + +class APIChainStatus(Base): + """API Chain Status - tracks the health status of APIs in chains""" + + __tablename__ = 'api_chain_status' + + uuid = sqlalchemy.Column(String(255), primary_key=True, unique=True) + chain_uuid = sqlalchemy.Column(String(255), nullable=False, index=True) + provider_uuid = sqlalchemy.Column(String(255), nullable=False, index=True) + + # Granularity: model-level and API-key-level tracking + model_name = sqlalchemy.Column(String(255), nullable=True, index=True) + """Model name (from LLMModel.name); NULL means provider-level status""" + + api_key_index = sqlalchemy.Column(Integer, nullable=True) + """Index into the provider's api_keys list; NULL means all/round-robin""" + + # Status tracking + is_healthy = sqlalchemy.Column(Boolean, nullable=False, default=True) + failure_count = sqlalchemy.Column(Integer, nullable=False, default=0) + last_failure_time = sqlalchemy.Column(DateTime, nullable=True) + last_success_time = sqlalchemy.Column(DateTime, nullable=True) + last_health_check_time = sqlalchemy.Column(DateTime, nullable=True) + + # Error information + last_error_message = sqlalchemy.Column(String(1024), nullable=True) + + health_check_last_failed = sqlalchemy.Column(Boolean, nullable=False, default=False) + """True when the last health-check probe itself failed 
(not a normal request failure). + Is_healthy remains False while this is True. Does NOT increment failure_count.""" + + # Metadata + created_at = sqlalchemy.Column(DateTime, nullable=False, server_default=sqlalchemy.func.now()) + updated_at = sqlalchemy.Column( + DateTime, + nullable=False, + server_default=sqlalchemy.func.now(), + onupdate=sqlalchemy.func.now(), + ) diff --git a/src/langbot/pkg/persistence/migrations/dbm008_api_chain.py b/src/langbot/pkg/persistence/migrations/dbm008_api_chain.py new file mode 100644 index 000000000..931a36322 --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm008_api_chain.py @@ -0,0 +1,46 @@ +"""Database migration for API Chain feature""" + +from sqlalchemy import text + + +async def migrate(ap): + """Add API chain tables""" + + # Create api_chains table + await ap.persistence_mgr.execute_async( + text(""" + CREATE TABLE IF NOT EXISTS api_chains ( + uuid VARCHAR(255) PRIMARY KEY, + name VARCHAR(255) NOT NULL, + description VARCHAR(512), + chain_config JSON NOT NULL, + health_check_interval INTEGER NOT NULL DEFAULT 300, + health_check_enabled BOOLEAN NOT NULL DEFAULT 1, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """) + ) + + # Create api_chain_status table + await ap.persistence_mgr.execute_async( + text(""" + CREATE TABLE IF NOT EXISTS api_chain_status ( + uuid VARCHAR(255) PRIMARY KEY, + chain_uuid VARCHAR(255) NOT NULL, + provider_uuid VARCHAR(255) NOT NULL, + is_healthy BOOLEAN NOT NULL DEFAULT 1, + failure_count INTEGER NOT NULL DEFAULT 0, + last_failure_time DATETIME, + last_success_time DATETIME, + last_health_check_time DATETIME, + last_error_message VARCHAR(1024), + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + INDEX idx_chain_uuid (chain_uuid), + INDEX idx_provider_uuid (provider_uuid) + ) + """) + ) + + ap.logger.info('API Chain tables created successfully') diff 
--git a/src/langbot/pkg/persistence/migrations/dbm020_api_chain_model_api_level.py b/src/langbot/pkg/persistence/migrations/dbm020_api_chain_model_api_level.py new file mode 100644 index 000000000..861d7b840 --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm020_api_chain_model_api_level.py @@ -0,0 +1,39 @@ +import sqlalchemy +from .. import migration + + +@migration.migration_class(20) +class DBMigrateAPIChainModelAPILevel(migration.DBMigration): + """Add model_name and api_key_index columns to api_chain_status for per-model/api-key health tracking""" + + async def upgrade(self): + """Upgrade""" + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE api_chain_status ADD COLUMN model_name VARCHAR(255) DEFAULT NULL') + ) + except Exception: + pass + + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE api_chain_status ADD COLUMN api_key_index INTEGER DEFAULT NULL') + ) + except Exception: + pass + + async def downgrade(self): + """Downgrade""" + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE api_chain_status DROP COLUMN model_name') + ) + except Exception: + pass + + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE api_chain_status DROP COLUMN api_key_index') + ) + except Exception: + pass diff --git a/src/langbot/pkg/persistence/migrations/dbm021_api_chain_health_check_flag.py b/src/langbot/pkg/persistence/migrations/dbm021_api_chain_health_check_flag.py new file mode 100644 index 000000000..d4d9a1c87 --- /dev/null +++ b/src/langbot/pkg/persistence/migrations/dbm021_api_chain_health_check_flag.py @@ -0,0 +1,27 @@ +import sqlalchemy +from .. 
import migration + + +@migration.migration_class(21) +class DBMigrateAPIChainHealthCheckFlag(migration.DBMigration): + """Add health_check_last_failed column to api_chain_status""" + + async def upgrade(self): + """Upgrade""" + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text( + 'ALTER TABLE api_chain_status ADD COLUMN health_check_last_failed BOOLEAN NOT NULL DEFAULT 0' + ) + ) + except Exception: + pass + + async def downgrade(self): + """Downgrade""" + try: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.text('ALTER TABLE api_chain_status DROP COLUMN health_check_last_failed') + ) + except Exception: + pass diff --git a/src/langbot/pkg/pipeline/cntfilter/filters/baiduexamine.py b/src/langbot/pkg/pipeline/cntfilter/filters/baiduexamine.py index a376310f6..4213e662b 100644 --- a/src/langbot/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/src/langbot/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -1,9 +1,10 @@ from __future__ import annotations +import aiohttp + from .. import entities from .. 
import filter as filter_model import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query -from langbot.pkg.utils import httpclient BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' @@ -14,50 +15,50 @@ class BaiduCloudExamine(filter_model.ContentFilter): """百度云内容审核""" async def _get_token(self) -> str: - session = httpclient.get_session() - async with session.post( - BAIDU_EXAMINE_TOKEN_URL, - params={ - 'grant_type': 'client_credentials', - 'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'], - 'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'], - }, - ) as resp: - return (await resp.json())['access_token'] + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_TOKEN_URL, + params={ + 'grant_type': 'client_credentials', + 'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'], + 'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'], + }, + ) as resp: + return (await resp.json())['access_token'] async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: - session = httpclient.get_session() - async with session.post( - BAIDU_EXAMINE_URL.format(await self._get_token()), - headers={ - 'Content-Type': 'application/x-www-form-urlencoded', - 'Accept': 'application/json', - }, - data=f'text={message}'.encode('utf-8'), - ) as resp: - result = await resp.json() - - if 'error_code' in result: - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement=message, - user_notice='', - console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}', - ) - else: - conclusion = result['conclusion'] - - if conclusion in ('合规'): + async with aiohttp.ClientSession() as session: + async with session.post( + BAIDU_EXAMINE_URL.format(await self._get_token()), + headers={ 
+ 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json', + }, + data=f'text={message}'.encode('utf-8'), + ) as resp: + result = await resp.json() + + if 'error_code' in result: return entities.FilterResult( - level=entities.ResultLevel.PASS, + level=entities.ResultLevel.BLOCK, replacement=message, user_notice='', - console_notice=f'百度云判定结果:{conclusion}', + console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}', ) else: - return entities.FilterResult( - level=entities.ResultLevel.BLOCK, - replacement=message, - user_notice='消息中存在不合适的内容, 请修改', - console_notice=f'百度云判定结果:{conclusion}', - ) + conclusion = result['conclusion'] + + if conclusion in ('合规'): + return entities.FilterResult( + level=entities.ResultLevel.PASS, + replacement=message, + user_notice='', + console_notice=f'百度云判定结果:{conclusion}', + ) + else: + return entities.FilterResult( + level=entities.ResultLevel.BLOCK, + replacement=message, + user_notice='消息中存在不合适的内容, 请修改', + console_notice=f'百度云判定结果:{conclusion}', + ) diff --git a/src/langbot/pkg/pipeline/preproc/preproc.py b/src/langbot/pkg/pipeline/preproc/preproc.py index cd039d796..9cf4d8b93 100644 --- a/src/langbot/pkg/pipeline/preproc/preproc.py +++ b/src/langbot/pkg/pipeline/preproc/preproc.py @@ -36,17 +36,20 @@ async def process( session = await self.ap.sess_mgr.get_session(query) # When not local-agent, llm_model is None - try: - llm_model = ( - await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model']) - if selected_runner == 'local-agent' - else None - ) - except ValueError: - self.ap.logger.warning( - f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured' - ) - llm_model = None + llm_model = None + use_api_chain_uuid = None + if selected_runner == 'local-agent': + model_value = query.pipeline_config['ai']['local-agent'].get('model', '') + if model_value: + try: + llm_model = await self.ap.model_mgr.get_model_by_uuid(model_value) + 
except ValueError: + # Not a model UUID — try as API chain UUID + chain = await self.ap.api_chain_mgr.get_chain(model_value) + if chain: + use_api_chain_uuid = model_value + else: + self.ap.logger.warning(f'LLM model/chain {model_value} not found or not configured') conversation = await self.ap.sess_mgr.get_conversation( query, @@ -61,19 +64,28 @@ async def process( query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - if selected_runner == 'local-agent' and llm_model: + if selected_runner == 'local-agent': query.use_funcs = [] - query.use_llm_model_uuid = llm_model.model_entity.uuid - - if llm_model.model_entity.abilities.__contains__('func_call'): - # Get bound plugins and MCP servers for filtering tools + if llm_model: + query.use_llm_model_uuid = llm_model.model_entity.uuid + + if llm_model.model_entity.abilities.__contains__('func_call'): + # Get bound plugins and MCP servers for filtering tools + bound_plugins = query.variables.get('_pipeline_bound_plugins', None) + bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) + query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) + + self.ap.logger.debug(f'Bound plugins: {bound_plugins}') + self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}') + self.ap.logger.debug(f'Use funcs: {query.use_funcs}') + + elif use_api_chain_uuid: + query.variables['_use_api_chain_uuid'] = use_api_chain_uuid + # Enable all tools for chain; individual models will decide capability bound_plugins = query.variables.get('_pipeline_bound_plugins', None) bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None) query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers) - - self.ap.logger.debug(f'Bound plugins: {bound_plugins}') - self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}') - self.ap.logger.debug(f'Use funcs: {query.use_funcs}') + self.ap.logger.debug(f'Using API chain 
{use_api_chain_uuid} for local-agent') sender_name = '' diff --git a/src/langbot/pkg/platform/sources/discord.py b/src/langbot/pkg/platform/sources/discord.py index e9cc7a37e..cb80ce48e 100644 --- a/src/langbot/pkg/platform/sources/discord.py +++ b/src/langbot/pkg/platform/sources/discord.py @@ -14,7 +14,7 @@ import asyncio from enum import Enum -from langbot.pkg.utils import httpclient +import aiohttp import pydantic import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter @@ -622,23 +622,23 @@ async def yiri2target( image_bytes = base64.b64decode(base64_data) elif ele.url: # 从URL下载图片 - session = httpclient.get_session() - async with session.get(ele.url) as response: - image_bytes = await response.read() - # 从URL或Content-Type推断文件类型 - content_type = response.headers.get('Content-Type', '') - if 'jpeg' in content_type or 'jpg' in content_type: - filename = f'{uuid.uuid4()}.jpg' - elif 'gif' in content_type: - filename = f'{uuid.uuid4()}.gif' - elif 'webp' in content_type: - filename = f'{uuid.uuid4()}.webp' - elif ele.url.lower().endswith(('.jpg', '.jpeg')): - filename = f'{uuid.uuid4()}.jpg' - elif ele.url.lower().endswith('.gif'): - filename = f'{uuid.uuid4()}.gif' - elif ele.url.lower().endswith('.webp'): - filename = f'{uuid.uuid4()}.webp' + async with aiohttp.ClientSession() as session: + async with session.get(ele.url) as response: + image_bytes = await response.read() + # 从URL或Content-Type推断文件类型 + content_type = response.headers.get('Content-Type', '') + if 'jpeg' in content_type or 'jpg' in content_type: + filename = f'{uuid.uuid4()}.jpg' + elif 'gif' in content_type: + filename = f'{uuid.uuid4()}.gif' + elif 'webp' in content_type: + filename = f'{uuid.uuid4()}.webp' + elif ele.url.lower().endswith(('.jpg', '.jpeg')): + filename = f'{uuid.uuid4()}.jpg' + elif ele.url.lower().endswith('.gif'): + filename = f'{uuid.uuid4()}.gif' + elif ele.url.lower().endswith('.webp'): + filename = f'{uuid.uuid4()}.webp' elif ele.path: # 
从文件路径读取图片 # 确保路径没有空字节 @@ -702,9 +702,9 @@ async def yiri2target( file_base64 = ele.base64.split(',')[-1] file_bytes = base64.b64decode(file_base64) elif ele.url: - session = httpclient.get_session() - async with session.get(ele.url) as response: - file_bytes = await response.read() + async with aiohttp.ClientSession() as session: + async with session.get(ele.url) as response: + file_bytes = await response.read() if file_bytes: files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename)) elif isinstance(ele, platform_message.File): @@ -717,9 +717,9 @@ async def yiri2target( else: file_bytes = base64.b64decode(ele.base64) elif ele.url: - session = httpclient.get_session() - async with session.get(ele.url) as response: - file_bytes = await response.read() + async with aiohttp.ClientSession() as session: + async with session.get(ele.url) as response: + file_bytes = await response.read() if file_bytes: files.append(discord.File(fp=io.BytesIO(file_bytes), filename=filename)) elif isinstance(ele, platform_message.Forward): @@ -775,12 +775,12 @@ def text_element_recur( # attachments for attachment in message.attachments: - session = httpclient.get_session(trust_env=True) - async with session.get(attachment.url) as response: - image_data = await response.read() - image_base64 = base64.b64encode(image_data).decode('utf-8') - image_format = response.headers['Content-Type'] - element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}')) + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get(attachment.url) as response: + image_data = await response.read() + image_base64 = base64.b64encode(image_data).decode('utf-8') + image_format = response.headers['Content-Type'] + element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}')) return platform_message.MessageChain(element_list) diff --git a/src/langbot/pkg/platform/sources/kook.py 
b/src/langbot/pkg/platform/sources/kook.py index 5a6bade36..17777a95e 100644 --- a/src/langbot/pkg/platform/sources/kook.py +++ b/src/langbot/pkg/platform/sources/kook.py @@ -9,8 +9,6 @@ import time import aiohttp - -from langbot.pkg.utils import httpclient import websockets import pydantic @@ -122,16 +120,16 @@ async def target2yiri(kook_message: dict, bot_account_id: str = '') -> platform_ if content: # Download image and convert to base64 try: - session = httpclient.get_session() - async with session.get(content) as response: - if response.status == 200: - image_bytes = await response.read() - image_base64 = base64.b64encode(image_bytes).decode('utf-8') - # Detect image format - content_type = response.headers.get('Content-Type', 'image/png') - components.append( - platform_message.Image(base64=f'data:{content_type};base64,{image_base64}') - ) + async with aiohttp.ClientSession() as session: + async with session.get(content) as response: + if response.status == 200: + image_bytes = await response.read() + image_base64 = base64.b64encode(image_bytes).decode('utf-8') + # Detect image format + content_type = response.headers.get('Content-Type', 'image/png') + components.append( + platform_message.Image(base64=f'data:{content_type};base64,{image_base64}') + ) except Exception: # If download fails, just add as plain text components.append(platform_message.Plain(text=f'[Image: {content}]')) @@ -297,17 +295,17 @@ async def _get_gateway_url(self) -> str: 'Authorization': f'Bot {self.config["token"]}', } - session = httpclient.get_session() - async with session.get(base_url, params=params, headers=headers) as response: - if response.status == 200: - data = await response.json() - if data.get('code') == 0: - gateway_url = data['data']['url'] - return gateway_url + async with aiohttp.ClientSession() as session: + async with session.get(base_url, params=params, headers=headers) as response: + if response.status == 200: + data = await response.json() + if data.get('code') == 
0: + gateway_url = data['data']['url'] + return gateway_url + else: + raise Exception(f'Failed to get gateway URL: {data.get("message")}') else: - raise Exception(f'Failed to get gateway URL: {data.get("message")}') - else: - raise Exception(f'Failed to get gateway URL: HTTP {response.status}') + raise Exception(f'Failed to get gateway URL: HTTP {response.status}') async def _get_bot_user_info(self) -> dict: """Get bot's own user information from KOOK API""" @@ -317,17 +315,17 @@ async def _get_bot_user_info(self) -> dict: 'Authorization': f'Bot {self.config["token"]}', } - session = httpclient.get_session() - async with session.get(base_url, headers=headers) as response: - if response.status == 200: - data = await response.json() - if data.get('code') == 0: - user_info = data['data'] - return user_info + async with aiohttp.ClientSession() as session: + async with session.get(base_url, headers=headers) as response: + if response.status == 200: + data = await response.json() + if data.get('code') == 0: + user_info = data['data'] + return user_info + else: + raise Exception(f'Failed to get bot user info: {data.get("message")}') else: - raise Exception(f'Failed to get bot user info: {data.get("message")}') - else: - raise Exception(f'Failed to get bot user info: HTTP {response.status}') + raise Exception(f'Failed to get bot user info: HTTP {response.status}') async def _handle_hello(self, data: dict): """Handle HELLO signal (signal 1)""" @@ -512,7 +510,7 @@ async def send_message(self, target_type: str, target_id: str, message: platform try: if not self.http_session: - self.http_session = httpclient.get_session() + self.http_session = aiohttp.ClientSession() async with self.http_session.post(url, json=payload, headers=headers) as response: if response.status == 200: @@ -578,7 +576,7 @@ async def reply_message( try: if not self.http_session: - self.http_session = httpclient.get_session() + self.http_session = aiohttp.ClientSession() async with 
self.http_session.post(url, json=payload, headers=headers) as response: if response.status == 200: @@ -626,7 +624,7 @@ async def run_async(self): try: # Create HTTP session - self.http_session = httpclient.get_session() + self.http_session = aiohttp.ClientSession() await self.logger.info('Starting KOOK adapter') diff --git a/src/langbot/pkg/platform/sources/lark.py b/src/langbot/pkg/platform/sources/lark.py index 3ce4280cd..ce5277311 100644 --- a/src/langbot/pkg/platform/sources/lark.py +++ b/src/langbot/pkg/platform/sources/lark.py @@ -17,7 +17,7 @@ import os import mimetypes -from langbot.pkg.utils import httpclient +import aiohttp import lark_oapi.ws.exception import quart from lark_oapi.api.im.v1 import * @@ -78,13 +78,13 @@ async def upload_image_to_lark(msg: platform_message.Image, api_client: lark_oap return None elif msg.url: try: - session = httpclient.get_session() - async with session.get(msg.url) as response: - if response.status == 200: - image_bytes = await response.read() - else: - print(f'Failed to download image from {msg.url}: HTTP {response.status}') - return None + async with aiohttp.ClientSession() as session: + async with session.get(msg.url) as response: + if response.status == 200: + image_bytes = await response.read() + else: + print(f'Failed to download image from {msg.url}: HTTP {response.status}') + return None except Exception as e: print(f'Failed to download image from {msg.url}: {e}') traceback.print_exc() @@ -208,10 +208,10 @@ async def _get_media_bytes( pass elif msg.url: try: - session = httpclient.get_session() - async with session.get(msg.url) as resp: - if resp.status == 200: - data = await resp.read() + async with aiohttp.ClientSession() as session: + async with session.get(msg.url) as resp: + if resp.status == 200: + data = await resp.read() except Exception: pass elif msg.path: diff --git a/src/langbot/pkg/platform/sources/legacy/gewechat.py b/src/langbot/pkg/platform/sources/legacy/gewechat.py index 68e1bdedd..93bef53cb 
100644 --- a/src/langbot/pkg/platform/sources/legacy/gewechat.py +++ b/src/langbot/pkg/platform/sources/legacy/gewechat.py @@ -9,7 +9,7 @@ import threading import quart -from langbot.pkg.utils import httpclient +import aiohttp import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter from ....core import app @@ -639,14 +639,14 @@ def unregister_listener( async def run_async(self): if not self.config['token']: - session = httpclient.get_session() - async with session.post( - f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId', - json={'app_id': self.config['app_id']}, - ) as response: - if response.status != 200: - raise Exception(f'获取gewechat token失败: {await response.text()}') - self.config['token'] = (await response.json())['data'] + async with aiohttp.ClientSession() as session: + async with session.post( + f'{self.config["gewechat_url"]}/v2/api/tools/getTokenId', + json={'app_id': self.config['app_id']}, + ) as response: + if response.status != 200: + raise Exception(f'获取gewechat token失败: {await response.text()}') + self.config['token'] = (await response.json())['data'] self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token']) diff --git a/src/langbot/pkg/platform/sources/telegram.py b/src/langbot/pkg/platform/sources/telegram.py index d43b9333c..c2b2fd032 100644 --- a/src/langbot/pkg/platform/sources/telegram.py +++ b/src/langbot/pkg/platform/sources/telegram.py @@ -10,9 +10,9 @@ import typing import traceback import base64 +import aiohttp import pydantic -from langbot.pkg.utils import httpclient import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter import langbot_plugin.api.entities.builtin.platform.message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events @@ -34,9 +34,9 @@ async def yiri2target(message_chain: platform_message.MessageChain, bot: telegra if component.base64: photo_bytes = 
base64.b64decode(component.base64) elif component.url: - session = httpclient.get_session() - async with session.get(component.url) as response: - photo_bytes = await response.read() + async with aiohttp.ClientSession() as session: + async with session.get(component.url) as response: + photo_bytes = await response.read() elif component.path: with open(component.path, 'rb') as f: photo_bytes = f.read() @@ -75,9 +75,10 @@ def parse_message_text(text: str) -> list[platform_message.MessageComponent]: file_bytes = None file_format = '' - async with httpclient.get_session(trust_env=True).get(file.file_path) as response: - file_bytes = await response.read() - file_format = 'image/jpeg' + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get(file.file_path) as response: + file_bytes = await response.read() + file_format = 'image/jpeg' message_components.append( platform_message.Image( @@ -94,8 +95,9 @@ def parse_message_text(text: str) -> list[platform_message.MessageComponent]: file_bytes = None file_format = message.voice.mime_type or 'audio/ogg' - async with httpclient.get_session(trust_env=True).get(file.file_path) as response: - file_bytes = await response.read() + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get(file.file_path) as response: + file_bytes = await response.read() message_components.append( platform_message.Voice( @@ -193,31 +195,7 @@ async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): ) async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): - components = await TelegramMessageConverter.yiri2target(message, self.bot) - - chat_id_str, _, thread_id_str = str(target_id).partition('#') - chat_id: int | str = int(chat_id_str) if chat_id_str.lstrip('-').isdigit() else chat_id_str - message_thread_id = int(thread_id_str) if thread_id_str and thread_id_str.isdigit() else None - - for component in components: - 
component_type = component.get('type') - args = {'chat_id': chat_id} - if message_thread_id is not None: - args['message_thread_id'] = message_thread_id - - if component_type == 'text': - text = component.get('text', '') - if self.config['markdown_card'] is True: - text = telegramify_markdown.markdownify(content=text) - args['parse_mode'] = 'MarkdownV2' - args['text'] = text - await self.bot.send_message(**args) - elif component_type == 'photo': - photo = component.get('photo') - if photo is None: - continue - args['photo'] = telegram.InputFile(photo) - await self.bot.send_photo(**args) + pass async def reply_message( self, diff --git a/src/langbot/pkg/platform/webhook_pusher.py b/src/langbot/pkg/platform/webhook_pusher.py index f3cf39b27..5a8d25644 100644 --- a/src/langbot/pkg/platform/webhook_pusher.py +++ b/src/langbot/pkg/platform/webhook_pusher.py @@ -3,8 +3,6 @@ import asyncio import logging import aiohttp - -from langbot.pkg.utils import httpclient import uuid from typing import TYPE_CHECKING @@ -121,23 +119,23 @@ async def _push_to_webhook(self, url: str, payload: dict) -> dict | None: dict | None: The response JSON if successful, None otherwise """ try: - session = httpclient.get_session() - async with session.post( - url, - json=payload, - headers={'Content-Type': 'application/json'}, - timeout=aiohttp.ClientTimeout(total=15), - ) as response: - if response.status >= 400: - self.logger.warning(f'Webhook {url} returned status {response.status}') - return None - else: - self.logger.debug(f'Successfully pushed to webhook {url}') - try: - return await response.json() - except Exception as json_error: - self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}') + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=payload, + headers={'Content-Type': 'application/json'}, + timeout=aiohttp.ClientTimeout(total=15), + ) as response: + if response.status >= 400: + self.logger.warning(f'Webhook {url} 
returned status {response.status}') return None + else: + self.logger.debug(f'Successfully pushed to webhook {url}') + try: + return await response.json() + except Exception as json_error: + self.logger.debug(f'Failed to parse JSON response from webhook {url}: {json_error}') + return None except asyncio.TimeoutError: self.logger.warning(f'Timeout pushing to webhook {url}') return None diff --git a/src/langbot/pkg/provider/modelmgr/api_chain.py b/src/langbot/pkg/provider/modelmgr/api_chain.py new file mode 100644 index 000000000..dc9316c77 --- /dev/null +++ b/src/langbot/pkg/provider/modelmgr/api_chain.py @@ -0,0 +1,757 @@ +"""API Chain Manager - handles API failover and health checking. + +chain_config item schema (per provider entry): +{ + "provider_uuid": "xxx", + "priority": 1, # provider priority in the chain (lower = higher priority) + "is_aggregated": false, + "max_retries": 3, + "timeout_ms": 30000, + "model_configs": [ # optional: per-model configuration + { + "model_name": "gpt-4o", # model name as stored in LLMModel.name + "priority": 1, # priority within this provider + "api_key_indices": [ # optional: per-API-key priority + {"index": 0, "priority": 1}, + {"index": 1, "priority": 2} + ] + } + ] +} +If model_configs is absent, the original query model is used with round-robin keys. +If api_key_indices is absent for a model config, round-robin rotation is used. +""" + +from __future__ import annotations + +import asyncio +import uuid as uuid_lib +from datetime import datetime +from typing import List, Dict, Any, Optional, Tuple, AsyncGenerator + +import sqlalchemy +from ...core import app +from ...entity.persistence import api_chain as api_chain_entity +from . import requester +from . 
import token + +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query +import langbot_plugin.api.entities.builtin.provider.message as provider_message +import langbot_plugin.api.entities.builtin.resource.tool as resource_tool + + +class APIChainManager: + """Manages API chains with per-model/per-API-key failover and health checking""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.chains: Dict[str, api_chain_entity.APIChain] = {} + self.health_check_tasks: Dict[str, asyncio.Task] = {} + + async def initialize(self): + """Initialize API chain manager""" + await self.load_chains_from_db() + await self.start_health_check_tasks() + + async def load_chains_from_db(self): + """Load all API chains from database""" + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(api_chain_entity.APIChain)) + for row in result.all(): + # result.all() returns read-only Row objects; wrap them in mutable instances + chain = api_chain_entity.APIChain( + uuid=row.uuid, + name=row.name, + description=row.description, + chain_config=row.chain_config, + health_check_interval=row.health_check_interval, + health_check_enabled=row.health_check_enabled, + ) + self.chains[chain.uuid] = chain + + async def start_health_check_tasks(self): + """Start background health check tasks for all chains""" + for chain_uuid, chain in self.chains.items(): + if chain.health_check_enabled: + task = asyncio.create_task(self._health_check_loop(chain_uuid)) + self.health_check_tasks[chain_uuid] = task + + async def stop_health_check_tasks(self): + """Stop all health check tasks""" + for task in self.health_check_tasks.values(): + task.cancel() + self.health_check_tasks.clear() + + # ==================== Health Check ==================== + + async def _health_check_loop(self, chain_uuid: str): + """Background loop for health checking failed APIs. 
+ + An immediate check is performed on startup so that pre-existing + unhealthy records are evaluated without waiting for the full interval. + """ + # Immediate check on start + try: + await self._perform_health_checks(chain_uuid) + except asyncio.CancelledError: + return + except Exception as e: + self.ap.logger.error(f'Initial health check error for chain {chain_uuid}: {e}') + + while True: + try: + chain = self.chains.get(chain_uuid) + if not chain or not chain.health_check_enabled: + break + await asyncio.sleep(chain.health_check_interval) + await self._perform_health_checks(chain_uuid) + except asyncio.CancelledError: + break + except Exception as e: + self.ap.logger.error(f'Health check loop error for chain {chain_uuid}: {e}') + await asyncio.sleep(60) + + async def _perform_health_checks(self, chain_uuid: str): + """Perform health checks on all unhealthy status records in a chain""" + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChainStatus).where( + sqlalchemy.and_( + api_chain_entity.APIChainStatus.chain_uuid == chain_uuid, + api_chain_entity.APIChainStatus.is_healthy == False, + ) + ) + ) + for status in result.all(): + try: + provider = self.ap.model_mgr.provider_dict.get(status.provider_uuid) + if not provider: + continue + + is_healthy = await self._check_api_health(status, provider) + + if is_healthy: + await self._update_status( + status.uuid, + is_healthy=True, + failure_count=0, + health_check_last_failed=False, + last_success_time=datetime.now(), + last_health_check_time=datetime.now(), + last_error_message=None, + ) + self.ap.logger.info( + f'API recovered: provider={status.provider_uuid} ' + f'model={status.model_name} key_index={status.api_key_index}' + ) + else: + # Health check probe failed: mark the flag but do NOT increment failure_count + await self._update_status( + status.uuid, + health_check_last_failed=True, + last_health_check_time=datetime.now(), + ) + except Exception as e: + 
self.ap.logger.error( + f'Health check loop error for provider={status.provider_uuid} ' + f'model={status.model_name} key_index={status.api_key_index}: {e}' + ) + try: + await self._update_status( + status.uuid, + health_check_last_failed=True, + last_health_check_time=datetime.now(), + ) + except Exception: + pass + + async def _check_api_health( + self, + status: api_chain_entity.APIChainStatus, + provider: requester.RuntimeProvider, + ) -> bool: + """Check API health by making a minimal test request to the LLM endpoint. + + Returns True if the request succeeds (API is reachable and authenticated), + False otherwise. Does NOT raise exceptions. + """ + try: + temp_provider = self._create_provider_for_key(provider, status.api_key_index) + model_entity = self._resolve_model_entity(provider, None, status.model_name) + if model_entity is None: + self.ap.logger.warning( + f'Health check: no model found for provider={status.provider_uuid} ' + f'model_name={status.model_name}, skipping' + ) + return False + + temp_model = requester.RuntimeLLMModel( + model_entity=model_entity, + provider=temp_provider, + ) + test_msg = provider_message.Message(role='user', content='hi') + + await temp_provider.invoke_llm( + query=None, + model=temp_model, + messages=[test_msg], + funcs=None, + extra_args={}, + remove_think=True, + ) + return True + except Exception as e: + self.ap.logger.debug( + f'Health check request failed for provider={status.provider_uuid} ' + f'model={status.model_name} key={status.api_key_index}: {e}' + ) + return False + + # ==================== Status Helpers ==================== + + async def _ensure_status( + self, + chain_uuid: str, + provider_uuid: str, + model_name: Optional[str], + api_key_index: Optional[int], + ) -> api_chain_entity.APIChainStatus: + """Get or create a status record for the given (chain, provider, model, key) tuple""" + conditions = [ + api_chain_entity.APIChainStatus.chain_uuid == chain_uuid, + 
api_chain_entity.APIChainStatus.provider_uuid == provider_uuid, + ] + if model_name is None: + conditions.append(api_chain_entity.APIChainStatus.model_name == None) # noqa: E711 + else: + conditions.append(api_chain_entity.APIChainStatus.model_name == model_name) + + if api_key_index is None: + conditions.append(api_chain_entity.APIChainStatus.api_key_index == None) # noqa: E711 + else: + conditions.append(api_chain_entity.APIChainStatus.api_key_index == api_key_index) + + result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChainStatus).where(sqlalchemy.and_(*conditions)) + ) + existing = result.first() + if existing: + return existing + + new_uuid = str(uuid_lib.uuid4()) + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(api_chain_entity.APIChainStatus).values( + uuid=new_uuid, + chain_uuid=chain_uuid, + provider_uuid=provider_uuid, + model_name=model_name, + api_key_index=api_key_index, + is_healthy=True, + failure_count=0, + ) + ) + result2 = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChainStatus).where(api_chain_entity.APIChainStatus.uuid == new_uuid) + ) + return result2.first() + + async def _update_status( + self, + status_uuid: str, + is_healthy: Optional[bool] = None, + failure_count: Optional[int] = None, + last_failure_time: Optional[datetime] = None, + last_success_time: Optional[datetime] = None, + last_health_check_time: Optional[datetime] = None, + last_error_message: Optional[str] = None, + health_check_last_failed: Optional[bool] = None, + ): + """Update a status record by UUID""" + update_data: Dict[str, Any] = {} + if is_healthy is not None: + update_data['is_healthy'] = is_healthy + if failure_count is not None: + update_data['failure_count'] = failure_count + if last_failure_time is not None: + update_data['last_failure_time'] = last_failure_time + if last_success_time is not None: + update_data['last_success_time'] = last_success_time + if 
last_health_check_time is not None: + update_data['last_health_check_time'] = last_health_check_time + if last_error_message is not None: + update_data['last_error_message'] = last_error_message + elif is_healthy: + # Clear error message when marking healthy + update_data['last_error_message'] = None + if health_check_last_failed is not None: + update_data['health_check_last_failed'] = health_check_last_failed + + if update_data: + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(api_chain_entity.APIChainStatus) + .where(api_chain_entity.APIChainStatus.uuid == status_uuid) + .values(**update_data) + ) + + # ==================== Chain CRUD ==================== + + async def get_chain(self, chain_uuid: str) -> Optional[api_chain_entity.APIChain]: + """Get an API chain by UUID""" + return self.chains.get(chain_uuid) + + async def create_chain(self, chain_data: Dict[str, Any]) -> str: + """Create a new API chain and start its health check loop""" + chain_uuid = chain_data.get('uuid', str(uuid_lib.uuid4())) + + chain = api_chain_entity.APIChain( + uuid=chain_uuid, + name=chain_data['name'], + description=chain_data.get('description', ''), + chain_config=chain_data.get('chain_config', []), + health_check_interval=chain_data.get('health_check_interval', 300), + health_check_enabled=chain_data.get('health_check_enabled', True), + ) + + # Use explicit column values to avoid SQLAlchemy internal state pollution + await self.ap.persistence_mgr.execute_async( + sqlalchemy.insert(api_chain_entity.APIChain).values( + uuid=chain.uuid, + name=chain.name, + description=chain.description, + chain_config=chain.chain_config, + health_check_interval=chain.health_check_interval, + health_check_enabled=chain.health_check_enabled, + ) + ) + self.chains[chain_uuid] = chain + + if chain.health_check_enabled: + task = asyncio.create_task(self._health_check_loop(chain_uuid)) + self.health_check_tasks[chain_uuid] = task + + return chain_uuid + + async def update_chain(self, 
chain_uuid: str, chain_data: Dict[str, Any]): + """Update an existing API chain""" + existing = self.chains.get(chain_uuid) + + # Collect current attribute values (may come from an in-memory instance or DB) + if existing is not None: + current = { + 'uuid': existing.uuid, + 'name': existing.name, + 'description': existing.description, + 'chain_config': existing.chain_config, + 'health_check_interval': existing.health_check_interval, + 'health_check_enabled': existing.health_check_enabled, + } + else: + db_result = await self.ap.persistence_mgr.execute_async( + sqlalchemy.select(api_chain_entity.APIChain).where(api_chain_entity.APIChain.uuid == chain_uuid) + ) + row = db_result.first() + if not row: + raise ValueError(f'Chain {chain_uuid} not found') + current = { + 'uuid': row.uuid, + 'name': row.name, + 'description': row.description, + 'chain_config': row.chain_config, + 'health_check_interval': row.health_check_interval, + 'health_check_enabled': row.health_check_enabled, + } + + # Merge incoming changes + for key, value in chain_data.items(): + if key in current and key != 'uuid': + current[key] = value + + # Persist changes to DB + await self.ap.persistence_mgr.execute_async( + sqlalchemy.update(api_chain_entity.APIChain) + .where(api_chain_entity.APIChain.uuid == chain_uuid) + .values(**{k: v for k, v in chain_data.items() if k != 'uuid'}) + ) + + # Rebuild mutable in-memory instance with merged data + new_chain = api_chain_entity.APIChain( + uuid=current['uuid'], + name=current['name'], + description=current.get('description', ''), + chain_config=current.get('chain_config', []), + health_check_interval=current.get('health_check_interval', 300), + health_check_enabled=current.get('health_check_enabled', True), + ) + self.chains[chain_uuid] = new_chain + + # Cancel existing task and restart to pick up new config immediately + existing_task = self.health_check_tasks.pop(chain_uuid, None) + if existing_task is not None: + existing_task.cancel() + if 
new_chain.health_check_enabled: + task = asyncio.create_task(self._health_check_loop(chain_uuid)) + self.health_check_tasks[chain_uuid] = task + + async def delete_chain(self, chain_uuid: str): + """Delete an API chain""" + if chain_uuid in self.health_check_tasks: + self.health_check_tasks[chain_uuid].cancel() + del self.health_check_tasks[chain_uuid] + + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(api_chain_entity.APIChain).where(api_chain_entity.APIChain.uuid == chain_uuid) + ) + await self.ap.persistence_mgr.execute_async( + sqlalchemy.delete(api_chain_entity.APIChainStatus).where( + api_chain_entity.APIChainStatus.chain_uuid == chain_uuid + ) + ) + self.chains.pop(chain_uuid, None) + + # ==================== Invoke Helpers ==================== + + def _build_invoke_tasks( + self, + model_configs: List[Dict[str, Any]], + ) -> List[Tuple[Optional[str], Optional[int]]]: + """Build an ordered list of (model_name, api_key_index) tuples to try. + + Returns [(None, None)] when model_configs is empty, meaning the caller's + original model and round-robin key rotation will be used (legacy behaviour). + When api_key_indices is configured for a model, each key index becomes a + separate failover task ordered by priority. When api_key_indices is absent, + the entry uses (model_name, None) so the provider's TokenManager performs + round-robin rotation across all configured keys. 
+        """
+        if not model_configs:
+            return [(None, None)]
+
+        tasks: List[Tuple[Optional[str], Optional[int]]] = []
+        sorted_models = sorted(model_configs, key=lambda x: x.get('priority', 0))
+        for mc in sorted_models:
+            model_name: Optional[str] = mc.get('model_name') or None
+            api_key_indices: List[Dict] = mc.get('api_key_indices') or []
+            if api_key_indices:
+                # Expand each configured key index as an independent failover task
+                sorted_keys = sorted(api_key_indices, key=lambda x: x.get('priority', 0))
+                for key_config in sorted_keys:
+                    tasks.append((model_name, key_config['index']))
+            else:
+                # No specific key configured: use round-robin rotation
+                tasks.append((model_name, None))
+        return tasks if tasks else [(None, None)]
+
+    def _create_provider_for_key(
+        self,
+        provider: requester.RuntimeProvider,
+        api_key_index: Optional[int],
+    ) -> requester.RuntimeProvider:
+        """Return a provider restricted to a single API key.
+
+        Creates a lightweight wrapper with a single-token TokenManager so that
+        shared mutable state on the original provider is not modified.
+        Returns the original provider unchanged when api_key_index is None.
+        """
+        if api_key_index is None:
+            return provider
+
+        tokens = provider.token_mgr.tokens
+        if not tokens or api_key_index >= len(tokens):
+            return provider  # index out of range — fall back gracefully
+
+        single_token_mgr = token.TokenManager(
+            name=provider.token_mgr.name,
+            tokens=[tokens[api_key_index]],
+        )
+        return requester.RuntimeProvider(
+            provider_entity=provider.provider_entity,
+            token_mgr=single_token_mgr,
+            requester=provider.requester,
+        )
+
+    def _resolve_model_entity(
+        self,
+        provider: requester.RuntimeProvider,
+        default_model: Optional[requester.RuntimeLLMModel],
+        model_name: Optional[str],
+    ) -> Any:
+        """Return the model entity for model_name under the given provider.
+
+        Falls back to default_model.model_entity when model_name is None or no
+        matching model is found.
When default_model is also None, falls back to + the first available model for the provider. + """ + if not model_name: + if default_model is not None: + return default_model.model_entity + for m in self.ap.model_mgr.llm_models: + if m.model_entity.provider_uuid == provider.provider_entity.uuid: + return m.model_entity + return None + + for m in self.ap.model_mgr.llm_models: + if m.model_entity.provider_uuid == provider.provider_entity.uuid and m.model_entity.name == model_name: + return m.model_entity + + if default_model is not None: + return default_model.model_entity + for m in self.ap.model_mgr.llm_models: + if m.model_entity.provider_uuid == provider.provider_entity.uuid: + return m.model_entity + return None + + # ==================== LLM Invocation ==================== + + async def invoke_chain_llm( + self, + chain_uuid: str, + query: pipeline_query.Query, + model: Optional[requester.RuntimeLLMModel], + messages: List[provider_message.Message], + funcs: Optional[List[resource_tool.LLMTool]] = None, + extra_args: Dict[str, Any] = {}, + remove_think: bool = False, + ) -> provider_message.Message: + """Invoke LLM through API chain with per-model/per-API-key failover""" + chain = self.chains.get(chain_uuid) + if not chain: + raise ValueError(f'Chain {chain_uuid} not found') + + sorted_items = sorted(chain.chain_config, key=lambda x: x.get('priority', 0)) + last_error: Optional[Exception] = None + + for item in sorted_items: + provider_uuid: str = item['provider_uuid'] + is_aggregated: bool = item.get('is_aggregated', False) + max_retries: int = item.get('max_retries', 3) + model_configs: List[Dict] = item.get('model_configs') or [] + + provider = self.ap.model_mgr.provider_dict.get(provider_uuid) + if not provider: + self.ap.logger.warning(f'Provider {provider_uuid} not found in chain {chain_uuid}') + continue + + tasks = self._build_invoke_tasks(model_configs) + + for task_model_name, task_api_key_index in tasks: + status = await 
self._ensure_status(chain_uuid, provider_uuid, task_model_name, task_api_key_index) + + if status and not status.is_healthy and not is_aggregated: + self.ap.logger.debug( + f'Skipping unhealthy: provider={provider_uuid} model={task_model_name} key={task_api_key_index}' + ) + continue + + temp_provider = self._create_provider_for_key(provider, task_api_key_index) + model_entity = self._resolve_model_entity(provider, model, task_model_name) + if model_entity is None: + self.ap.logger.warning( + f'No model found for provider {provider_uuid} in chain {chain_uuid}, skipping' + ) + continue + temp_model = requester.RuntimeLLMModel( + model_entity=model_entity, + provider=temp_provider, + ) + + retry_count = 0 if is_aggregated else max_retries + + for attempt in range(max(1, retry_count + 1)): + try: + result = await temp_provider.invoke_llm( + query=query, + model=temp_model, + messages=messages, + funcs=funcs, + extra_args=extra_args, + remove_think=remove_think, + ) + + # Advance round-robin token rotation on success + if task_api_key_index is None: + provider.token_mgr.next_token() + + if status: + await self._update_status( + status.uuid, + is_healthy=True, + failure_count=0, + health_check_last_failed=False, + last_success_time=datetime.now(), + ) + return result + + except Exception as e: + last_error = e + self.ap.logger.warning( + f'Chain {chain_uuid} provider={provider_uuid} ' + f'model={task_model_name} key={task_api_key_index} ' + f'attempt {attempt + 1}/{max(1, retry_count + 1)} failed: {e}' + ) + # Advance round-robin token rotation on failure too + if task_api_key_index is None: + provider.token_mgr.next_token() + + if attempt + 1 >= max(1, retry_count + 1): + # All retries exhausted for this (model, key) task + if is_aggregated: + # Aggregated: track failure count but keep is_healthy=True + if status: + await self._update_status( + status.uuid, + failure_count=(status.failure_count or 0) + 1, + last_failure_time=datetime.now(), + 
last_error_message=str(e)[:1024], + ) + else: + if status: + await self._update_status( + status.uuid, + is_healthy=False, + failure_count=(status.failure_count or 0) + 1, + health_check_last_failed=False, + last_failure_time=datetime.now(), + last_error_message=str(e)[:1024], + ) + break # Move to next (model_name, key_index) task + + error_msg = f'All providers in chain {chain_uuid} failed' + if last_error: + error_msg += f': {last_error}' + raise Exception(error_msg) + + async def invoke_chain_llm_stream( + self, + chain_uuid: str, + query: pipeline_query.Query, + model: Optional[requester.RuntimeLLMModel], + messages: List[provider_message.Message], + funcs: Optional[List[resource_tool.LLMTool]] = None, + extra_args: Dict[str, Any] = {}, + remove_think: bool = False, + ) -> AsyncGenerator[provider_message.MessageChunk, None]: + """Invoke LLM stream through API chain with per-model/per-API-key failover""" + chain = self.chains.get(chain_uuid) + if not chain: + raise ValueError(f'Chain {chain_uuid} not found') + + sorted_items = sorted(chain.chain_config, key=lambda x: x.get('priority', 0)) + last_error: Optional[Exception] = None + # True if the stream started yielding and then failed mid-flight. + # In this case we must NOT fall through to the next provider because + # partial output has already been sent to the caller. 
+ failed_mid_stream: bool = False + + for item in sorted_items: + provider_uuid: str = item['provider_uuid'] + is_aggregated: bool = item.get('is_aggregated', False) + max_retries: int = item.get('max_retries', 3) + model_configs: List[Dict] = item.get('model_configs') or [] + + provider = self.ap.model_mgr.provider_dict.get(provider_uuid) + if not provider: + self.ap.logger.warning(f'Provider {provider_uuid} not found in chain {chain_uuid}') + continue + + tasks = self._build_invoke_tasks(model_configs) + + for task_model_name, task_api_key_index in tasks: + status = await self._ensure_status(chain_uuid, provider_uuid, task_model_name, task_api_key_index) + + if status and not status.is_healthy and not is_aggregated: + self.ap.logger.debug( + f'Skipping unhealthy: provider={provider_uuid} model={task_model_name} key={task_api_key_index}' + ) + continue + + temp_provider = self._create_provider_for_key(provider, task_api_key_index) + model_entity = self._resolve_model_entity(provider, model, task_model_name) + if model_entity is None: + self.ap.logger.warning( + f'No model found for provider {provider_uuid} in chain {chain_uuid}, skipping' + ) + continue + temp_model = requester.RuntimeLLMModel( + model_entity=model_entity, + provider=temp_provider, + ) + + retry_count = 0 if is_aggregated else max_retries + + for attempt in range(max(1, retry_count + 1)): + has_yielded = False + try: + async for chunk in temp_provider.invoke_llm_stream( + query=query, + model=temp_model, + messages=messages, + funcs=funcs, + extra_args=extra_args, + remove_think=remove_think, + ): + has_yielded = True + yield chunk + + # Advance round-robin token rotation on success + if task_api_key_index is None: + provider.token_mgr.next_token() + + if status: + await self._update_status( + status.uuid, + is_healthy=True, + failure_count=0, + health_check_last_failed=False, + last_success_time=datetime.now(), + ) + return + + except Exception as e: + last_error = e + self.ap.logger.warning( + 
f'Chain {chain_uuid} provider={provider_uuid} ' + f'model={task_model_name} key={task_api_key_index} ' + f'stream attempt {attempt + 1}/{max(1, retry_count + 1)} failed: {e}' + ) + # Advance round-robin token rotation on failure too + if task_api_key_index is None: + provider.token_mgr.next_token() + + if has_yielded or attempt + 1 >= max(1, retry_count + 1): + # Cannot retry if chunks were already yielded (would duplicate output), + # or all retries are exhausted for this task + if is_aggregated: + # Aggregated: track failure count but keep is_healthy=True + if status: + await self._update_status( + status.uuid, + failure_count=(status.failure_count or 0) + 1, + last_failure_time=datetime.now(), + last_error_message=str(e)[:1024], + ) + else: + if status: + await self._update_status( + status.uuid, + is_healthy=False, + failure_count=(status.failure_count or 0) + 1, + health_check_last_failed=False, + last_failure_time=datetime.now(), + last_error_message=str(e)[:1024], + ) + if has_yielded: + # Partial output already sent to caller; stop processing + # entirely to avoid mixing output from different providers. 
+ failed_mid_stream = True + break # Move to next (model_name, key_index) task + if failed_mid_stream: + break # Exit for-task loop + if failed_mid_stream: + break # Exit for-item loop + + if failed_mid_stream and last_error: + # Re-raise original exception; caller already received partial chunks + raise last_error + + error_msg = f'All providers in chain {chain_uuid} failed' + if last_error: + error_msg += f': {last_error}' + raise Exception(error_msg) diff --git a/src/langbot/pkg/provider/runners/localagent.py b/src/langbot/pkg/provider/runners/localagent.py index 52e78b9d4..a4ccbd8ee 100644 --- a/src/langbot/pkg/provider/runners/localagent.py +++ b/src/langbot/pkg/provider/runners/localagent.py @@ -119,23 +119,37 @@ async def run( remove_think = query.pipeline_config['output'].get('misc', '').get('remove-think') - use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid) + use_api_chain_uuid = (query.variables or {}).get('_use_api_chain_uuid') + use_llm_model = ( + None if use_api_chain_uuid else await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid) + ) self.ap.logger.debug( - f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}' + f'localagent req: query={query.query_id} req_messages={req_messages} ' + f'use_llm_model={query.use_llm_model_uuid} use_api_chain={use_api_chain_uuid}' ) if not is_stream: # 非流式输出,直接请求 - - msg = await use_llm_model.provider.invoke_llm( - query, - use_llm_model, - req_messages, - query.use_funcs, - extra_args=use_llm_model.model_entity.extra_args, - remove_think=remove_think, - ) + if use_api_chain_uuid: + msg = await self.ap.api_chain_mgr.invoke_chain_llm( + use_api_chain_uuid, + query, + None, + req_messages, + query.use_funcs, + extra_args={}, + remove_think=remove_think, + ) + else: + msg = await use_llm_model.provider.invoke_llm( + query, + use_llm_model, + req_messages, + query.use_funcs, + 
extra_args=use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ) yield msg final_msg = msg else: @@ -145,14 +159,26 @@ async def run( accumulated_content = '' # 从开始累积的所有内容 last_role = 'assistant' msg_sequence = 1 - async for msg in use_llm_model.provider.invoke_llm_stream( - query, - use_llm_model, - req_messages, - query.use_funcs, - extra_args=use_llm_model.model_entity.extra_args, - remove_think=remove_think, - ): + if use_api_chain_uuid: + stream_src = self.ap.api_chain_mgr.invoke_chain_llm_stream( + use_api_chain_uuid, + query, + None, + req_messages, + query.use_funcs, + extra_args={}, + remove_think=remove_think, + ) + else: + stream_src = use_llm_model.provider.invoke_llm_stream( + query, + use_llm_model, + req_messages, + query.use_funcs, + extra_args=use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ) + async for msg in stream_src: msg_idx = msg_idx + 1 # 记录角色 @@ -253,7 +279,8 @@ async def run( req_messages.append(err_msg) self.ap.logger.debug( - f'localagent req: query={query.query_id} req_messages={req_messages} use_llm_model={query.use_llm_model_uuid}' + f'localagent req: query={query.query_id} req_messages={req_messages} ' + f'use_llm_model={query.use_llm_model_uuid} use_api_chain={use_api_chain_uuid}' ) if is_stream: @@ -263,14 +290,26 @@ async def run( last_role = 'assistant' msg_sequence = first_end_sequence - async for msg in use_llm_model.provider.invoke_llm_stream( - query, - use_llm_model, - req_messages, - query.use_funcs, - extra_args=use_llm_model.model_entity.extra_args, - remove_think=remove_think, - ): + if use_api_chain_uuid: + tool_stream_src = self.ap.api_chain_mgr.invoke_chain_llm_stream( + use_api_chain_uuid, + query, + None, + req_messages, + query.use_funcs, + extra_args={}, + remove_think=remove_think, + ) + else: + tool_stream_src = use_llm_model.provider.invoke_llm_stream( + query, + use_llm_model, + req_messages, + query.use_funcs, + extra_args=use_llm_model.model_entity.extra_args, + 
remove_think=remove_think, + ) + async for msg in tool_stream_src: msg_idx += 1 # 记录角色 @@ -319,14 +358,25 @@ async def run( ) else: # 处理完所有调用,再次请求 - msg = await use_llm_model.provider.invoke_llm( - query, - use_llm_model, - req_messages, - query.use_funcs, - extra_args=use_llm_model.model_entity.extra_args, - remove_think=remove_think, - ) + if use_api_chain_uuid: + msg = await self.ap.api_chain_mgr.invoke_chain_llm( + use_api_chain_uuid, + query, + None, + req_messages, + query.use_funcs, + extra_args={}, + remove_think=remove_think, + ) + else: + msg = await use_llm_model.provider.invoke_llm( + query, + use_llm_model, + req_messages, + query.use_funcs, + extra_args=use_llm_model.model_entity.extra_args, + remove_think=remove_think, + ) yield msg final_msg = msg diff --git a/src/langbot/pkg/provider/runners/n8nsvapi.py b/src/langbot/pkg/provider/runners/n8nsvapi.py index d177d6b81..d7ec3ccbf 100644 --- a/src/langbot/pkg/provider/runners/n8nsvapi.py +++ b/src/langbot/pkg/provider/runners/n8nsvapi.py @@ -5,8 +5,6 @@ import uuid import aiohttp -from langbot.pkg.utils import httpclient - from .. 
import runner from ...core import app import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @@ -219,50 +217,50 @@ async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenera self.ap.logger.debug('no auth') # 调用webhook - session = httpclient.get_session() - if is_stream: - # 流式请求 - async with session.post( - self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout - ) as response: - if response.status != 200: - error_text = await response.text() - self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') - raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') - - # 处理流式响应 - async for chunk in self._process_stream_response(response): - yield chunk - else: - async with session.post( - self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout - ) as response: - try: - async for chunk in self._process_stream_response(response): - output_content = chunk.content if chunk.is_final else '' - except: - # 非流式请求(保持原有逻辑) + async with aiohttp.ClientSession() as session: + if is_stream: + # 流式请求 + async with session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: if response.status != 200: error_text = await response.text() self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') - # 解析响应 - response_data = await response.json() - self.ap.logger.debug(f'n8n webhook response: {response_data}') - - # 从响应中提取输出 - if self.output_key in response_data: - output_content = response_data[self.output_key] - else: - # 如果没有指定的输出键,则使用整个响应 - output_content = json.dumps(response_data, ensure_ascii=False) - - # 返回消息 - yield provider_message.Message( - role='assistant', - content=output_content, - ) + # 处理流式响应 + async for chunk in self._process_stream_response(response): + yield chunk + else: + async with 
session.post( + self.webhook_url, json=payload, headers=headers, auth=auth, timeout=self.timeout + ) as response: + try: + async for chunk in self._process_stream_response(response): + output_content = chunk.content if chunk.is_final else '' + except: + # 非流式请求(保持原有逻辑) + if response.status != 200: + error_text = await response.text() + self.ap.logger.error(f'n8n webhook call failed: {response.status}, {error_text}') + raise Exception(f'n8n webhook call failed: {response.status}, {error_text}') + + # 解析响应 + response_data = await response.json() + self.ap.logger.debug(f'n8n webhook response: {response_data}') + + # 从响应中提取输出 + if self.output_key in response_data: + output_content = response_data[self.output_key] + else: + # 如果没有指定的输出键,则使用整个响应 + output_content = json.dumps(response_data, ensure_ascii=False) + + # 返回消息 + yield provider_message.Message( + role='assistant', + content=output_content, + ) except Exception as e: self.ap.logger.error(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}') diff --git a/src/langbot/pkg/utils/image.py b/src/langbot/pkg/utils/image.py index 5716b07d6..e07caec67 100644 --- a/src/langbot/pkg/utils/image.py +++ b/src/langbot/pkg/utils/image.py @@ -5,8 +5,6 @@ import ssl import aiohttp - -from langbot.pkg.utils import httpclient import PIL.Image import httpx @@ -49,54 +47,53 @@ async def get_gewechat_image_base64( ) try: - session = httpclient.get_session() - # 获取图片下载链接 - try: - async with session.post( - f'{gewechat_url}/v2/api/message/downloadImage', - headers=headers, - json={'appId': app_id, 'type': image_type, 'xml': xml_content}, - timeout=timeout, - ) as response: - if response.status != 200: - # print(response) - raise Exception(f'获取gewechat图片下载失败: {await response.text()}') - - resp_data = await response.json() - if resp_data.get('ret') != 200: - raise Exception(f'获取gewechat图片下载链接失败: {resp_data}') - - file_url = resp_data['data']['fileUrl'] - except asyncio.TimeoutError: - raise 
Exception('获取图片下载链接超时') - except aiohttp.ClientError as e: - raise Exception(f'获取图片下载链接网络错误: {str(e)}') - - # 解析原始URL并替换端口 - base_url = gewechat_file_url - download_url = f'{base_url}/download/{file_url}' - - # 下载图片 - try: - async with session.get(download_url) as img_response: - if img_response.status != 200: - raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}') - - image_data = await img_response.read() - - content_type = img_response.headers.get('Content-Type', '') - if content_type: - image_format = content_type.split('/')[-1] - else: - image_format = file_url.split('.')[-1] - - base64_str = base64.b64encode(image_data).decode('utf-8') - - return base64_str, image_format - except asyncio.TimeoutError: - raise Exception(f'下载图片超时, URL: {download_url}') - except aiohttp.ClientError as e: - raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}') + async with aiohttp.ClientSession(timeout=timeout) as session: + # 获取图片下载链接 + try: + async with session.post( + f'{gewechat_url}/v2/api/message/downloadImage', + headers=headers, + json={'appId': app_id, 'type': image_type, 'xml': xml_content}, + ) as response: + if response.status != 200: + # print(response) + raise Exception(f'获取gewechat图片下载失败: {await response.text()}') + + resp_data = await response.json() + if resp_data.get('ret') != 200: + raise Exception(f'获取gewechat图片下载链接失败: {resp_data}') + + file_url = resp_data['data']['fileUrl'] + except asyncio.TimeoutError: + raise Exception('获取图片下载链接超时') + except aiohttp.ClientError as e: + raise Exception(f'获取图片下载链接网络错误: {str(e)}') + + # 解析原始URL并替换端口 + base_url = gewechat_file_url + download_url = f'{base_url}/download/{file_url}' + + # 下载图片 + try: + async with session.get(download_url) as img_response: + if img_response.status != 200: + raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}') + + image_data = await img_response.read() + + content_type = img_response.headers.get('Content-Type', '') + if content_type: + 
image_format = content_type.split('/')[-1] + else: + image_format = file_url.split('.')[-1] + + base64_str = base64.b64encode(image_data).decode('utf-8') + + return base64_str, image_format + except asyncio.TimeoutError: + raise Exception(f'下载图片超时, URL: {download_url}') + except aiohttp.ClientError as e: + raise Exception(f'下载图片网络错误: {str(e)}, URL: {download_url}') except Exception as e: raise Exception(f'获取图片失败: {str(e)}') from e @@ -107,24 +104,24 @@ async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: :param pic_url: 企业微信图片URL :return: (base64_str, image_format) """ - session = httpclient.get_session() - async with session.get(pic_url) as response: - if response.status != 200: - raise Exception(f'Failed to download image: {response.status}') + async with aiohttp.ClientSession() as session: + async with session.get(pic_url) as response: + if response.status != 200: + raise Exception(f'Failed to download image: {response.status}') - # 读取图片数据 - image_data = await response.read() + # 读取图片数据 + image_data = await response.read() - # 获取图片格式 - content_type = response.headers.get('Content-Type', '') - image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg' + # 获取图片格式 + content_type = response.headers.get('Content-Type', '') + image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg' - # 转换为 base64 - import base64 + # 转换为 base64 + import base64 - image_base64 = base64.b64encode(image_data).decode('utf-8') + image_base64 = base64.b64encode(image_data).decode('utf-8') - return image_base64, image_format + return image_base64, image_format async def get_qq_official_image_base64(pic_url: str, content_type: str) -> tuple[str, str]: @@ -155,19 +152,21 @@ async def get_qq_image_bytes(image_url: str, query: dict = {}) -> tuple[bytes, s ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - session = httpclient.get_session() - async with session.get(image_url, params=query, 
ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0)) as resp: - resp.raise_for_status() - file_bytes = await resp.read() - content_type = resp.headers.get('Content-Type') - if not content_type: - image_format = 'jpeg' - elif not content_type.startswith('image/'): - pil_img = PIL.Image.open(io.BytesIO(file_bytes)) - image_format = pil_img.format.lower() - else: - image_format = content_type.split('/')[-1] - return file_bytes, image_format + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get( + image_url, params=query, ssl=ssl_context, timeout=aiohttp.ClientTimeout(total=30.0) + ) as resp: + resp.raise_for_status() + file_bytes = await resp.read() + content_type = resp.headers.get('Content-Type') + if not content_type: + image_format = 'jpeg' + elif not content_type.startswith('image/'): + pil_img = PIL.Image.open(io.BytesIO(file_bytes)) + image_format = pil_img.format.lower() + else: + image_format = content_type.split('/')[-1] + return file_bytes, image_format async def qq_image_url_to_base64(image_url: str) -> typing.Tuple[str, str]: @@ -205,11 +204,11 @@ async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, st async def get_slack_image_to_base64(pic_url: str, bot_token: str): headers = {'Authorization': f'Bearer {bot_token}'} try: - session = httpclient.get_session() - async with session.get(pic_url, headers=headers) as resp: - mime_type = resp.headers.get('Content-Type', 'application/octet-stream') - file_bytes = await resp.read() - base64_str = base64.b64encode(file_bytes).decode('utf-8') - return f'data:{mime_type};base64,{base64_str}' + async with aiohttp.ClientSession() as session: + async with session.get(pic_url, headers=headers) as resp: + mime_type = resp.headers.get('Content-Type', 'application/octet-stream') + file_bytes = await resp.read() + base64_str = base64.b64encode(file_bytes).decode('utf-8') + return f'data:{mime_type};base64,{base64_str}' except Exception as e: raise (e) diff 
--git a/src/langbot/templates/metadata/pipeline/ai.yaml b/src/langbot/templates/metadata/pipeline/ai.yaml index 7a13b2b14..974684475 100644 --- a/src/langbot/templates/metadata/pipeline/ai.yaml +++ b/src/langbot/templates/metadata/pipeline/ai.yaml @@ -57,9 +57,9 @@ stages: config: - name: model label: - en_US: Model - zh_Hans: 模型 - type: llm-model-selector + en_US: Model / API Chain + zh_Hans: 模型 / API 链 + type: model-or-api-chain-selector required: true - name: max-round label: diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index 18b74c95d..d9a1d0e10 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -11,6 +11,9 @@ const compat = new FlatCompat({ }); const eslintConfig = [ + { + ignores: ['.next/**', 'node_modules/**'], + }, ...compat.extends('next/core-web-vitals', 'next/typescript'), eslintPluginPrettierRecommended, ]; diff --git a/web/src/app/home/bots/components/bot-form/BotForm.tsx b/web/src/app/home/bots/components/bot-form/BotForm.tsx index d2ca22d79..1e2439575 100644 --- a/web/src/app/home/bots/components/bot-form/BotForm.tsx +++ b/web/src/app/home/bots/components/bot-form/BotForm.tsx @@ -124,12 +124,6 @@ export default function BotForm({ const currentAdapter = form.watch('adapter'); const currentAdapterConfig = form.watch('adapter_config'); - // Serialize adapter_config to a stable string so it can be used as a - // useEffect dependency without triggering on every render. form.watch() - // returns a new object reference each time, which would otherwise cause - // the filtering effect below to loop indefinitely. 
- const adapterConfigJson = JSON.stringify(currentAdapterConfig); - useEffect(() => { setBotFormValues(); }, []); @@ -153,7 +147,7 @@ export default function BotForm({ // For non-Lark adapters, show all fields setFilteredDynamicFormConfigList(dynamicFormConfigList); } - }, [currentAdapter, adapterConfigJson, dynamicFormConfigList]); + }, [currentAdapter, currentAdapterConfig, dynamicFormConfigList]); // 复制到剪贴板的辅助函数 - 使用页面上的真实input元素 const copyToClipboard = () => { diff --git a/web/src/app/home/components/api-chains-dialog/APIChainCard.tsx b/web/src/app/home/components/api-chains-dialog/APIChainCard.tsx new file mode 100644 index 000000000..9d7cb3e27 --- /dev/null +++ b/web/src/app/home/components/api-chains-dialog/APIChainCard.tsx @@ -0,0 +1,495 @@ +'use client'; + +import { + APIChain, + APIChainStatus, + LLMModel, + ModelProvider, +} from '@/app/infra/entities/api'; +import { Card, CardContent } from '@/components/ui/card'; +import { Button } from '@/components/ui/button'; +import { + Edit, + Trash2, + ChevronDown, + ChevronRight, + AlertCircle, + CheckCircle2, + AlertTriangle, +} from 'lucide-react'; +import { useState, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { cn } from '@/lib/utils'; + +interface APIChainCardProps { + chain: APIChain; + providers: ModelProvider[]; + llmModels: LLMModel[]; + onEdit: () => void; + onDelete: () => void; +} + +function calculateHealthPercentage( + statuses: APIChainStatus[] | undefined, + chainConfig: APIChain['chain_config'], +): number { + if (!statuses || statuses.length === 0 || chainConfig.length === 0) + return 100; + // Collect UUIDs of aggregated providers so their statuses are excluded from + // the unhealthy calculation (they never become unhealthy by design). 
+ const aggregatedUuids = new Set( + chainConfig.filter((c) => c.is_aggregated).map((c) => c.provider_uuid), + ); + const trackable = statuses.filter( + (s) => !aggregatedUuids.has(s.provider_uuid), + ); + if (trackable.length === 0) return 100; + const healthyCount = trackable.filter((s) => s.is_healthy).length; + return Math.round((healthyCount / trackable.length) * 100); +} + +function getErrorStats(statuses: APIChainStatus[] | undefined): { + totalFailures: number; +} { + if (!statuses || statuses.length === 0) return { totalFailures: 0 }; + return { + totalFailures: statuses.reduce((sum, s) => sum + (s.failure_count || 0), 0), + }; +} + +function getHealthColorClass(healthPercentage: number): string { + if (healthPercentage === 0) return 'border-destructive bg-destructive/5'; + if (healthPercentage < 50) return 'border-yellow-500 bg-yellow-500/5'; + return ''; +} + +function getHealthIcon(healthPercentage: number) { + if (healthPercentage === 0) + return ; + if (healthPercentage < 50) + return ; + return ; +} + +export default function APIChainCard({ + chain, + providers, + llmModels: _llmModels, // eslint-disable-line @typescript-eslint/no-unused-vars + onEdit, + onDelete, +}: APIChainCardProps) { + const { t } = useTranslation(); + const [expanded, setExpanded] = useState(false); + + const getProviderName = (uuid: string) => + providers.find((p) => p.uuid === uuid)?.name ?? 
uuid; + + const sortedConfigs = [...chain.chain_config].sort( + (a, b) => a.priority - b.priority, + ); + + const healthPercentage = useMemo( + () => calculateHealthPercentage(chain.statuses, chain.chain_config), + [chain.statuses, chain.chain_config], + ); + + const { totalFailures } = useMemo( + () => getErrorStats(chain.statuses), + [chain.statuses], + ); + + /** Get all status records for a given (provider, model_name, api_key_index) combination */ + function getStatus( + providerUuid: string, + modelName: string | null, + apiKeyIndex: number | null, + ): APIChainStatus | undefined { + return chain.statuses?.find( + (s) => + s.provider_uuid === providerUuid && + (s.model_name ?? null) === modelName && + (s.api_key_index ?? null) === apiKeyIndex, + ); + } + + /** Get all status records for a provider (any granularity) */ + function getProviderStatuses(providerUuid: string): APIChainStatus[] { + return ( + chain.statuses?.filter((s) => s.provider_uuid === providerUuid) ?? [] + ); + } + + /** Compute provider-level health summary */ + function providerHealthSummary( + providerUuid: string, + isAggregated: boolean, + ): { healthy: number; total: number } { + // Aggregated providers never become unhealthy by design; always report full health. + if (isAggregated) return { healthy: 1, total: 1 }; + const ss = getProviderStatuses(providerUuid); + if (ss.length === 0) return { healthy: 1, total: 1 }; // assume healthy if no data + return { healthy: ss.filter((s) => s.is_healthy).length, total: ss.length }; + } + + return ( + + + {/* Header */} +
+
+
+ +

{chain.name}

+ {getHealthIcon(healthPercentage)} + + {healthPercentage}% + +
+ {chain.description && ( +

+ {chain.description} +

+ )} +
+ {t('apiChains.providerCount', { + count: chain.chain_config.length, + })} +
+
+
+ + +
+
+ + {/* Expanded: per-provider / per-model / per-key health */} + {expanded && ( +
+ {sortedConfigs.map((config, index) => { + const { healthy, total } = providerHealthSummary( + config.provider_uuid, + config.is_aggregated, + ); + const providerHealthy = healthy === total; + const modelConfigs = config.model_configs ?? []; + + return ( +
+ {/* Provider row */} +
+ + {index + 1}. {getProviderName(config.provider_uuid)} + + {config.is_aggregated && ( + + {t('apiChains.aggregation')} + + )} + + {providerHealthy + ? t('apiChains.healthy') + : `${healthy}/${total}`} + +
+ + {/* Per-model breakdown */} + {modelConfigs.length > 0 ? ( +
+ {[...modelConfigs] + .sort((a, b) => a.priority - b.priority) + .map((mc, mi) => { + // ── Aggregated mode ────────────────────────────────────────── + // Aggregated providers never become unhealthy and are not subject + // to health checks. After retries are exhausted the chain simply + // moves to the next model. Always render a green "healthy" badge + // and omit the per-key sub-list (not meaningful here). + if (config.is_aggregated) { + // Sum failure counts across all key variants for display. + const aggStatuses = [ + getStatus( + config.provider_uuid, + mc.model_name, + null, + ), + ...(mc.api_key_indices ?? []).map((k) => + getStatus( + config.provider_uuid, + mc.model_name, + k.index, + ), + ), + ].filter(Boolean) as APIChainStatus[]; + const aggFailures = aggStatuses.reduce( + (sum, s) => sum + (s.failure_count || 0), + 0, + ); + return ( +
+
+ + #{mi + 1} + + + {mc.model_name} + + + {t('apiChains.healthy')} + + {aggFailures > 0 && ( + + {t('apiChains.failureCount')}:{' '} + {aggFailures} + + )} +
+
+ ); + } + + // ── Non-aggregated mode ────────────────────────────────────── + // When api_key_indices are configured, look up the status for + // each specific key; otherwise fall back to the round-robin + // (api_key_index=null) record. + const apiKeyIndices = mc.api_key_indices ?? []; + const relevantStatuses: APIChainStatus[] = + apiKeyIndices.length > 0 + ? (apiKeyIndices + .map((k) => + getStatus( + config.provider_uuid, + mc.model_name, + k.index, + ), + ) + .filter(Boolean) as APIChainStatus[]) + : ([ + getStatus( + config.provider_uuid, + mc.model_name, + null, + ), + ].filter(Boolean) as APIChainStatus[]); + const modelHealthy = + relevantStatuses.length === 0 || + relevantStatuses.every((s) => s.is_healthy); + const totalFailures = relevantStatuses.reduce( + (sum, s) => sum + (s.failure_count || 0), + 0, + ); + const modelHcFailed = relevantStatuses.some( + (s) => + !s.is_healthy && !!s.health_check_last_failed, + ); + const modelLastError = relevantStatuses.find( + (s) => + s.last_error_message && + !s.health_check_last_failed, + )?.last_error_message; + return ( +
+
+ + #{mi + 1} + + + {mc.model_name} + + + {modelHealthy + ? t('apiChains.healthy') + : t('apiChains.unhealthy')} + + {/* health check failed badge */} + {!modelHealthy && modelHcFailed && ( + + {t('apiChains.healthCheckFailed')} + + )} + {totalFailures > 0 && ( + + {t('apiChains.failureCount')}:{' '} + {totalFailures} + + )} +
+ {modelLastError && ( +

+ {t('apiChains.lastError')}: {modelLastError} +

+ )} + {/* Per-key sub-list when api_key_indices are configured */} + {apiKeyIndices.length > 0 && ( +
+ {[...apiKeyIndices] + .sort((a, b) => a.priority - b.priority) + .map((k) => { + const kst = getStatus( + config.provider_uuid, + mc.model_name, + k.index, + ); + const keyHealthy = kst + ? kst.is_healthy + : true; + return ( +
+ + key[{k.index}] + + + {keyHealthy + ? t('apiChains.healthy') + : t('apiChains.unhealthy')} + + {kst && + !kst.is_healthy && + kst.health_check_last_failed && ( + + {t( + 'apiChains.healthCheckFailed', + )} + + )} + {kst && kst.failure_count > 0 && ( + + {kst.failure_count}× + + )} +
+ ); + })} +
+ )} +
+ ); + })} +
+ ) : ( + /* No model configs: show provider-level status detail */ +
+
+ {t('apiChains.maxRetries')}: {config.max_retries} +
+
+ {t('apiChains.timeout')}: {config.timeout_ms}ms +
+ {(() => { + const st = getStatus(config.provider_uuid, null, null); + if (!st) return null; + return ( + <> + {/* health check failed badge (only for non-aggregated unhealthy) */} + {!st.is_healthy && st.health_check_last_failed && ( +
+ ⚠ {t('apiChains.healthCheckFailed')} +
+ )} + {st.failure_count > 0 && ( +
+ {t('apiChains.failureCount')}:{' '} + {st.failure_count} +
+ )} + {st.last_error_message && + !st.health_check_last_failed && ( +
+ {t('apiChains.lastError')}:{' '} + {st.last_error_message} +
+ )} + + ); + })()} +
+ )} +
+ ); + })} +
+ )} + + {/* Bottom summary */} +
+ 0 && 'text-destructive')}> + {t('apiChains.errorCount')}: {totalFailures} + +
+
+
+ ); +} diff --git a/web/src/app/home/components/api-chains-dialog/APIChainForm.tsx b/web/src/app/home/components/api-chains-dialog/APIChainForm.tsx new file mode 100644 index 000000000..39568c956 --- /dev/null +++ b/web/src/app/home/components/api-chains-dialog/APIChainForm.tsx @@ -0,0 +1,665 @@ +'use client'; + +import { useState, useEffect } from 'react'; +import { httpClient } from '@/app/infra/http/HttpClient'; +import { + APIChainItem, + APIChainModelConfig, + LLMModel, + ModelProvider, +} from '@/app/infra/entities/api'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { Textarea } from '@/components/ui/textarea'; +import { Switch } from '@/components/ui/switch'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { toast } from 'sonner'; +import { useTranslation } from 'react-i18next'; +import { + Plus, + Trash2, + GripVertical, + ChevronDown, + ChevronRight, +} from 'lucide-react'; + +interface APIChainFormProps { + chainId?: string; + providers: ModelProvider[]; + llmModels: LLMModel[]; + onFormSubmit: () => void; + onFormCancel: () => void; +} + +export default function APIChainForm({ + chainId, + providers, + llmModels, + onFormSubmit, + onFormCancel, +}: APIChainFormProps) { + const { t } = useTranslation(); + const [loading, setLoading] = useState(false); + const [name, setName] = useState(''); + const [description, setDescription] = useState(''); + const [chainConfig, setChainConfig] = useState([]); + const [healthCheckEnabled, setHealthCheckEnabled] = useState(true); + const [healthCheckInterval, setHealthCheckInterval] = useState(300); + /** Track which provider item has model_configs expanded */ + const [expandedModelConfigs, setExpandedModelConfigs] = useState>( + new Set(), + ); + /** Track which provider item has advanced config expanded */ + const [expandedAdvanced, 
setExpandedAdvanced] = useState>( + new Set(), + ); + + useEffect(() => { + if (chainId) { + loadChain(); + } else { + setChainConfig([ + { + provider_uuid: '', + priority: 1, + is_aggregated: false, + max_retries: 3, + timeout_ms: 30000, + model_configs: [], + }, + ]); + } + }, [chainId]); + + async function loadChain() { + try { + setLoading(true); + const resp = await httpClient.getAPIChain(chainId!); + setName(resp.chain.name); + setDescription(resp.chain.description || ''); + setChainConfig(resp.chain.chain_config); + setHealthCheckEnabled(resp.chain.health_check_enabled); + setHealthCheckInterval(resp.chain.health_check_interval); + } catch (err) { + toast.error(t('apiChains.loadError') + ': ' + (err as Error).message); + } finally { + setLoading(false); + } + } + + /** Return models belonging to the given provider */ + function modelsForProvider(providerUuid: string): LLMModel[] { + return llmModels.filter((m) => m.provider_uuid === providerUuid); + } + + // ---- Provider item CRUD ---- + + function addProvider() { + const maxPriority = Math.max(...chainConfig.map((c) => c.priority), 0); + setChainConfig([ + ...chainConfig, + { + provider_uuid: '', + priority: maxPriority + 1, + is_aggregated: false, + max_retries: 3, + timeout_ms: 30000, + model_configs: [], + }, + ]); + } + + function removeProvider(index: number) { + const newConfig = chainConfig.filter((_, i) => i !== index); + newConfig.forEach((c, i) => { + c.priority = i + 1; + }); + setChainConfig(newConfig); + setExpandedModelConfigs((prev) => { + const next = new Set(); + prev.forEach((v) => { + if (v < index) next.add(v); + else if (v > index) next.add(v - 1); + }); + return next; + }); + setExpandedAdvanced((prev) => { + const next = new Set(); + prev.forEach((v) => { + if (v < index) next.add(v); + else if (v > index) next.add(v - 1); + }); + return next; + }); + } + + function updateProvider( + index: number, + field: keyof APIChainItem, + value: string | number | boolean, + ) { + const 
newConfig = [...chainConfig]; + newConfig[index] = { ...newConfig[index], [field]: value }; + // If provider changes, reset model_configs + if (field === 'provider_uuid') { + newConfig[index].model_configs = []; + } + setChainConfig(newConfig); + } + + function moveProvider(index: number, direction: 'up' | 'down') { + if ( + (direction === 'up' && index === 0) || + (direction === 'down' && index === chainConfig.length - 1) + ) + return; + const newConfig = [...chainConfig]; + const target = direction === 'up' ? index - 1 : index + 1; + [newConfig[index], newConfig[target]] = [ + newConfig[target], + newConfig[index], + ]; + newConfig.forEach((c, i) => { + c.priority = i + 1; + }); + setChainConfig(newConfig); + } + + // ---- Model config CRUD ---- + + function addModelConfig(providerIndex: number) { + const newConfig = [...chainConfig]; + const existing = newConfig[providerIndex].model_configs ?? []; + const maxPriority = Math.max(...existing.map((m) => m.priority), 0); + newConfig[providerIndex] = { + ...newConfig[providerIndex], + model_configs: [ + ...existing, + { model_name: '', priority: maxPriority + 1 }, + ], + }; + setChainConfig(newConfig); + } + + function removeModelConfig(providerIndex: number, modelIndex: number) { + const newConfig = [...chainConfig]; + const existing = (newConfig[providerIndex].model_configs ?? []).filter( + (_, i) => i !== modelIndex, + ); + existing.forEach((m, i) => { + m.priority = i + 1; + }); + newConfig[providerIndex] = { + ...newConfig[providerIndex], + model_configs: existing, + }; + setChainConfig(newConfig); + } + + function updateModelConfig( + providerIndex: number, + modelIndex: number, + field: keyof APIChainModelConfig, + value: string | number | boolean, + ) { + const newConfig = [...chainConfig]; + const models = [...(newConfig[providerIndex].model_configs ?? 
[])]; + models[modelIndex] = { ...models[modelIndex], [field]: value }; + newConfig[providerIndex] = { + ...newConfig[providerIndex], + model_configs: models, + }; + setChainConfig(newConfig); + } + + function moveModelConfig( + providerIndex: number, + modelIndex: number, + direction: 'up' | 'down', + ) { + const models = [...(chainConfig[providerIndex].model_configs ?? [])]; + if ( + (direction === 'up' && modelIndex === 0) || + (direction === 'down' && modelIndex === models.length - 1) + ) + return; + const target = direction === 'up' ? modelIndex - 1 : modelIndex + 1; + [models[modelIndex], models[target]] = [models[target], models[modelIndex]]; + models.forEach((m, i) => { + m.priority = i + 1; + }); + const newConfig = [...chainConfig]; + newConfig[providerIndex] = { + ...newConfig[providerIndex], + model_configs: models, + }; + setChainConfig(newConfig); + } + + // ---- Submit ---- + + async function handleSubmit() { + if (!name.trim()) { + toast.error(t('apiChains.nameRequired')); + return; + } + if (chainConfig.length === 0) { + toast.error(t('apiChains.atLeastOneProvider')); + return; + } + for (const config of chainConfig) { + if (!config.provider_uuid) { + toast.error(t('apiChains.selectAllProviders')); + return; + } + for (const mc of config.model_configs ?? []) { + if (!mc.model_name) { + toast.error(t('apiChains.selectAllModels')); + return; + } + } + } + + setLoading(true); + try { + const data = { + name, + description: description || undefined, + chain_config: chainConfig, + health_check_enabled: healthCheckEnabled, + health_check_interval: healthCheckInterval, + }; + + if (chainId) { + await httpClient.updateAPIChain(chainId, data); + toast.success(t('apiChains.updateSuccess')); + } else { + await httpClient.createAPIChain(data); + toast.success(t('apiChains.createSuccess')); + } + onFormSubmit(); + } catch (err) { + toast.error( + t(chainId ? 
'apiChains.updateError' : 'apiChains.createError') + + ': ' + + (err as Error).message, + ); + } finally { + setLoading(false); + } + } + + return ( +
+ {/* Name */} +
+ + setName(e.target.value)} + placeholder={t('apiChains.namePlaceholder')} + /> +
+ + {/* Description */} +
+ +