Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions ruoyi-fastapi-backend/common/router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
import importlib
import os
import sys
Expand Down Expand Up @@ -306,17 +307,8 @@ def _find_controller_files(self) -> list[str]:

:return: py文件路径列表
"""
controller_files = []
# 遍历所有目录,查找controller目录
for root, _dirs, files in os.walk(self.project_root):
# 检查当前目录是否为controller目录
if os.path.basename(root) == 'controller':
# 遍历controller目录下的所有py文件
for file in files:
if file.endswith('.py') and not file.startswith('__'):
file_path = os.path.join(root, file)
controller_files.append(file_path)
return controller_files
pattern = os.path.join(self.project_root, '*', 'controller', '[!_]*.py')
return sorted(glob.glob(pattern))

def _import_module_and_get_routers(self, controller_files: list[str]) -> list[tuple[str, APIRouter]]:
"""
Expand All @@ -333,9 +325,8 @@ def _import_module_and_get_routers(self, controller_files: list[str]) -> list[tu

# 动态导入模块
module = importlib.import_module(module_name)
# 遍历模块属性,寻找APIRouter和APIRouterPro实例
for attr_name in dir(module):
attr = getattr(module, attr_name)
# 直接遍历模块__dict__,只检查模块自身定义的属性
for attr_name, attr in module.__dict__.items():
# 对于APIRouterPro实例,只有当auto_register=True时才添加
if isinstance(attr, APIRouterPro):
if attr.auto_register:
Expand Down
2 changes: 1 addition & 1 deletion ruoyi-fastapi-backend/config/get_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def __find_recent_workday(cls, day: int) -> int:
'password': RedisConfig.redis_password,
'db': RedisConfig.redis_database,
}
executors = {'default': AsyncIOExecutor(), 'processpool': ProcessPoolExecutor(5)}
job_defaults = {'coalesce': False, 'max_instance': 1}
scheduler = AsyncIOScheduler()

Expand Down Expand Up @@ -186,6 +185,7 @@ def _configure_scheduler(cls) -> None:
'sqlalchemy': SQLAlchemyJobStore(url=SYNC_SQLALCHEMY_DATABASE_URL, engine=cls._get_jobstore_engine()),
'redis': RedisJobStore(**redis_config),
}
executors = {'default': AsyncIOExecutor(), 'processpool': ProcessPoolExecutor(5)}
scheduler.configure(jobstores=job_stores, executors=executors, job_defaults=job_defaults)
cls._scheduler_configured = True

Expand Down
182 changes: 97 additions & 85 deletions ruoyi-fastapi-backend/utils/ai_util.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,115 @@
from agno.db.base import AsyncBaseDb
from agno.db.mysql import AsyncMySQLDb
from agno.db.postgres import AsyncPostgresDb
from agno.models.aimlapi import AIMLAPI
from agno.models.anthropic import Claude
from agno.models.base import Model
from agno.models.cerebras import Cerebras, CerebrasOpenAI
from agno.models.cohere import Cohere
from agno.models.cometapi import CometAPI
from agno.models.dashscope import DashScope
from agno.models.deepinfra import DeepInfra
from agno.models.deepseek import DeepSeek
from agno.models.fireworks import Fireworks
from agno.models.google import Gemini
from agno.models.groq import Groq
from agno.models.huggingface import HuggingFace
from agno.models.langdb import LangDB
from agno.models.litellm import LiteLLM, LiteLLMOpenAI
from agno.models.llama_cpp import LlamaCpp
from agno.models.lmstudio import LMStudio
from agno.models.meta import Llama
from agno.models.mistral import MistralChat
from agno.models.n1n import N1N
from agno.models.nebius import Nebius
from agno.models.nexus import Nexus
from agno.models.nvidia import Nvidia
from agno.models.ollama import Ollama
from agno.models.openai import OpenAIChat
from agno.models.openai.responses import OpenAIResponses
from agno.models.openrouter import OpenRouter
from agno.models.perplexity import Perplexity
from agno.models.portkey import Portkey
from agno.models.requesty import Requesty
from agno.models.sambanova import Sambanova
from agno.models.siliconflow import Siliconflow
from agno.models.together import Together
from agno.models.vercel import V0
from agno.models.vllm import VLLM
from agno.models.xai import xAI
from importlib import import_module
from typing import TYPE_CHECKING

from config.database import async_engine
from config.env import DataBaseConfig

provider_model_map: dict[str, type[Model]] = {
'AIMLAPI': AIMLAPI,
'Anthropic': Claude,
'Cerebras': Cerebras,
'CerebrasOpenAI': CerebrasOpenAI,
'Cohere': Cohere,
'CometAPI': CometAPI,
'DashScope': DashScope,
'DeepInfra': DeepInfra,
'DeepSeek': DeepSeek,
'Fireworks': Fireworks,
'Google': Gemini,
'Groq': Groq,
'HuggingFace': HuggingFace,
'LangDB': LangDB,
'LiteLLM': LiteLLM,
'LiteLLMOpenAI': LiteLLMOpenAI,
'LlamaCpp': LlamaCpp,
'LMStudio': LMStudio,
'Meta': Llama,
'Mistral': MistralChat,
'N1N': N1N,
'Nebius': Nebius,
'Nexus': Nexus,
'Nvidia': Nvidia,
'Ollama': Ollama,
'OpenAI': OpenAIChat,
'OpenAIResponses': OpenAIResponses,
'OpenRouter': OpenRouter,
'Perplexity': Perplexity,
'Portkey': Portkey,
'Requesty': Requesty,
'Sambanova': Sambanova,
'SiliconFlow': Siliconflow,
'Together': Together,
'Vercel': V0,
'VLLM': VLLM,
'xAI': xAI,
}
if TYPE_CHECKING:
from agno.db.base import AsyncBaseDb
from agno.models.base import Model

# Provider name -> (module path, class name). Classes are imported lazily on
# first use so that importing this module does not load every AI provider SDK
# at application startup.
_PROVIDER_REGISTRY: dict[str, tuple[str, str]] = {
    'AIMLAPI': ('agno.models.aimlapi', 'AIMLAPI'),
    'Anthropic': ('agno.models.anthropic', 'Claude'),
    'Cerebras': ('agno.models.cerebras', 'Cerebras'),
    'CerebrasOpenAI': ('agno.models.cerebras', 'CerebrasOpenAI'),
    'Cohere': ('agno.models.cohere', 'Cohere'),
    'CometAPI': ('agno.models.cometapi', 'CometAPI'),
    'DashScope': ('agno.models.dashscope', 'DashScope'),
    'DeepInfra': ('agno.models.deepinfra', 'DeepInfra'),
    'DeepSeek': ('agno.models.deepseek', 'DeepSeek'),
    'Fireworks': ('agno.models.fireworks', 'Fireworks'),
    'Google': ('agno.models.google', 'Gemini'),
    'Groq': ('agno.models.groq', 'Groq'),
    'HuggingFace': ('agno.models.huggingface', 'HuggingFace'),
    'LangDB': ('agno.models.langdb', 'LangDB'),
    'LiteLLM': ('agno.models.litellm', 'LiteLLM'),
    'LiteLLMOpenAI': ('agno.models.litellm', 'LiteLLMOpenAI'),
    'LlamaCpp': ('agno.models.llama_cpp', 'LlamaCpp'),
    'LMStudio': ('agno.models.lmstudio', 'LMStudio'),
    'Meta': ('agno.models.meta', 'Llama'),
    'Mistral': ('agno.models.mistral', 'MistralChat'),
    'N1N': ('agno.models.n1n', 'N1N'),
    'Nebius': ('agno.models.nebius', 'Nebius'),
    'Nexus': ('agno.models.nexus', 'Nexus'),
    'Nvidia': ('agno.models.nvidia', 'Nvidia'),
    'Ollama': ('agno.models.ollama', 'Ollama'),
    'OpenAI': ('agno.models.openai', 'OpenAIChat'),
    'OpenAIResponses': ('agno.models.openai.responses', 'OpenAIResponses'),
    'OpenRouter': ('agno.models.openrouter', 'OpenRouter'),
    'Perplexity': ('agno.models.perplexity', 'Perplexity'),
    'Portkey': ('agno.models.portkey', 'Portkey'),
    'Requesty': ('agno.models.requesty', 'Requesty'),
    'Sambanova': ('agno.models.sambanova', 'Sambanova'),
    'SiliconFlow': ('agno.models.siliconflow', 'Siliconflow'),
    'Together': ('agno.models.together', 'Together'),
    'Vercel': ('agno.models.vercel', 'V0'),
    'VLLM': ('agno.models.vllm', 'VLLM'),
    'xAI': ('agno.models.xai', 'xAI'),
}

storage_engine_map: dict[str, type[AsyncBaseDb]] = {
'mysql': AsyncMySQLDb,
'postgresql': AsyncPostgresDb,
# 存储引擎名称 -> (模块路径, 类名) 的映射
_STORAGE_ENGINE_REGISTRY: dict[str, tuple[str, str]] = {
'mysql': ('agno.db.mysql', 'AsyncMySQLDb'),
'postgresql': ('agno.db.postgres', 'AsyncPostgresDb'),
}

# Caches of already-imported classes, so each module is resolved via
# import_module at most once per process.
_provider_class_cache: dict[str, 'type[Model]'] = {}
_storage_class_cache: dict[str, 'type[AsyncBaseDb]'] = {}


def _resolve_provider_class(provider: str) -> 'type[Model] | None':
    """
    Load the model class for *provider* on demand and memoize it.

    :param provider: Provider name as registered in ``_PROVIDER_REGISTRY``
    :return: The provider's model class, or ``None`` when the provider is unknown
    """
    cached = _provider_class_cache.get(provider)
    if cached is not None:
        return cached
    try:
        module_path, class_name = _PROVIDER_REGISTRY[provider]
    except KeyError:
        # Provider not registered; caller decides how to fall back
        return None
    resolved = getattr(import_module(module_path), class_name)
    _provider_class_cache[provider] = resolved
    return resolved


def _resolve_storage_class(db_type: str) -> 'type[AsyncBaseDb]':
    """
    Load the storage-engine class for *db_type* on demand and memoize it.

    Unknown database types fall back to the MySQL engine.

    :param db_type: Database type key (e.g. ``'mysql'``, ``'postgresql'``)
    :return: The async storage engine class
    """
    cached = _storage_class_cache.get(db_type)
    if cached is not None:
        return cached
    # Default to the MySQL entry when the type is not registered
    module_path, class_name = _STORAGE_ENGINE_REGISTRY.get(db_type, _STORAGE_ENGINE_REGISTRY['mysql'])
    resolved = getattr(import_module(module_path), class_name)
    _storage_class_cache[db_type] = resolved
    return resolved


class AiUtil:
"""
AI工具类
"""

@classmethod
def get_storage_engine(cls) -> AsyncBaseDb:
def get_storage_engine(cls) -> 'AsyncBaseDb':
"""
获取存储引擎实例

:return: 存储引擎实例
"""
storage_engine_class = storage_engine_map.get(DataBaseConfig.db_type, AsyncMySQLDb)
storage_engine_class = _resolve_storage_class(DataBaseConfig.db_type)

return storage_engine_class(
db_engine=async_engine,
Expand Down Expand Up @@ -128,7 +137,7 @@ def get_model_from_factory(
temperature: float | None = None,
max_tokens: int | None = None,
**kwargs,
) -> Model:
) -> 'Model':
"""
从工厂获取模型实例

Expand All @@ -155,6 +164,9 @@ def get_model_from_factory(
params['host'] = base_url
if provider == 'DashScope' and not base_url:
params['base_url'] = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
model_class = provider_model_map.get(provider, OpenAIChat)
model_class = _resolve_provider_class(provider)
if model_class is None:
# 未知提供商,回退到OpenAI
model_class = _resolve_provider_class('OpenAI')

return model_class(**params)
18 changes: 12 additions & 6 deletions ruoyi-fastapi-backend/utils/server_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ class IPUtil:
IP工具类
"""

_PREFERRED_DNS_HOSTS: tuple[str, str] = ('223.5.5.5', '8.8.8.8')
_DNS_CONNECT_TIMEOUT = 1

@classmethod
def get_local_ip(cls) -> str:
"""
Expand Down Expand Up @@ -472,12 +475,15 @@ def get_network_ips(cls) -> list[str]:

# 优先显示首选出站IP
preferred_ip = None
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
preferred_ip = s.getsockname()[0]
except Exception:
pass
for dns_host in cls._PREFERRED_DNS_HOSTS:
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.settimeout(cls._DNS_CONNECT_TIMEOUT)
s.connect((dns_host, 80))
preferred_ip = s.getsockname()[0]
break
except Exception:
continue

if preferred_ip:
if preferred_ip in network_ips:
Expand Down