diff --git a/.gitignore b/.gitignore index 722a4bc..158b840 100644 --- a/.gitignore +++ b/.gitignore @@ -1,292 +1,63 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[codz] -*$py.class +# Environment variables +.env +.env.* +!.env.example +*.backup -# 忽略所有以 _cache 结尾的目录 -*_cache/ +# Logs +*.log +logs/ -# 忽略所有以 _cache 结尾的文件 -*_cache +# Outputs +outputs/ +dataflow_agent/ -# C extensions +# Python +__pycache__/ +*.py[cod] +*$py.class *.so - -dataflow_agent/tmps/ -dataflow_cache/ -.dfavenv -.dfavenv/ -dataflow_agent/toolkits/multimodaltool/models/ -dataflow_agent/toolkits/multimodaltool/onnx/ - -# Distribution / packaging .Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -/lib/ -/lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cove -*.py.cover -.hypothesis/ -.pytest_cache/ -cover/ -tests/ -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
-# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock -#poetry.toml - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. -# https://pdm-project.org/en/latest/usage/project/#working-with-version-control -#pdm.lock -#pdm.toml -.pdm-python -.pdm-build/ - -# pixi -# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. -#pixi.lock -# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one -# in the .venv directory. It is recommended not to include this directory in version control. -.pixi - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.envrc -.venv -env/ venv/ +.venv/ +env/ ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy +*.egg-info/ +.pytest_cache/ .mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# Abstra -# Abstra is an AI-powered process automation framework. -# Ignore directories containing user credentials, local state, and settings. -# Learn more at https://abstra.io/docs -.abstra/ - -# Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore -# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, -# you could uncomment the following to ignore the entire vscode folder -# .vscode/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - -# Cursor -# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to -# exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data -# refer to https://docs.cursor.com/context/ignore-files -.cursorignore -.cursorindexingignore - -# Claude Code -CLAUDE.md - -# Marimo -marimo/_static/ -marimo/_lsp/ -__marimo__/ +.coverage +htmlcov/ -# Frontend / Node.js +# Node.js node_modules/ -frontend-workflow/node_modules/ -frontend-workflow/dist/ -frontend-workflow/build/ -frontend-workflow/.vite/ -*.local - -# Logs +dist/ +build/ +.next/ npm-debug.log* yarn-debug.log* yarn-error.log* -pnpm-debug.log* -lerna-debug.log* -dataflow_agent.log* -dataflow_agent.log.* - -# Editor directories and files -.DS_Store -*.suo -*.ntvs* -*.njsproj -*.sln -*.sw? - -# TypeScript cache -*.tsbuildinfo -outputs/ -fastapi_app/invite_codes.txt -models/* -outputs/* -tmps/* -data/* -sam_b.pt -# fastapi -fastapi_app/outputs/ -nohup.out +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ -invite_codes.txt - -.vscode - -logs/ -conf/ - -logs -dev-debug.log -# Dependency directories -# Environment variables -.idea -# OS specific - -# Task files -# tasks.json -# tasks/ - -# Taskmaster AI (local task management) -.taskmaster/ - -script/optimize_readme_assets.py +# OS +.DS_Store +Thumbs.db -# Personal notes -iphone.md +# Database +*.db +*.sqlite +*.sqlite3 -.claude +# Temporary files +*.tmp +*.temp +*.bak -rebuttal_sessions/ +# Static uploads +static/uploads/ +.claude/ diff --git a/README.md b/README.md index 82a55fb..799cea0 100644 --- a/README.md +++ b/README.md @@ -1,226 +1,386 @@
-OpenNotebook Logo +OpenNotebookLM # OpenNotebookLM [![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white)](https://www.python.org/) +[![Node](https://img.shields.io/badge/Node-18+-339933?style=flat-square&logo=node.js&logoColor=white)](https://nodejs.org/) [![License](https://img.shields.io/badge/License-Apache_2.0-2F80ED?style=flat-square&logo=apache&logoColor=white)](LICENSE) -中文 | [English](README_EN.md) +English | [中文](README_ZH.md) -✨ **NotebookLM 风格的知识库工作流平台:上传文档、智能问答、一键生成 PPT / 思维导图 / 播客 / DrawIO 图表** ✨ - -| 📚 **知识库管理**  |  💬 **智能问答**  |  🎨 **多模态生成**  |  🔍 **语义检索** | - -
- - - Quickstart - - - Docs - - - Contributing - - -
-
+**Open-source NotebookLM alternative** — Upload documents, chat with sources, generate PPTs / mind maps / podcasts / DrawIO diagrams / flashcards / quizzes / deep research reports in one click
--- -## 📑 目录 +## 📅 Changelog -- [✨ 核心功能](#-核心功能) -- [📸 展示](#-展示) -- [🚀 快速开始](#-快速开始) -- [📂 项目结构](#-项目结构) -- [🤝 参与贡献](#-参与贡献) +- **2026.03.11** — Code refactoring: strict layered architecture; integrated local TTS model ([Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice)); added source-based note QA editing (Notion AI style); UI improvements; simplified configuration structure +- **2026.03.08** — Added user management: Supabase email + OTP authentication, multi-user data isolation, email-based user directories; cleaned up deprecated scripts +- **2026.02.27** — Integrated [Qwen-DeepResearch](https://github.com/Alibaba-NLP/DeepResearch) deep research module; PPT generation now supports Nano Banana 2 image model +- **2026.02.13** — Initial release --- -## ✨ 核心功能 +## 📸 Screenshots -> 以「笔记本 + 知识库」为核心,基于 DataFlow-Agent 工作流引擎,从上传的文档/论文出发,支持智能问答与多种一键生成能力。 +
+Dashboard +

Dashboard — Notebook management

+
-- **📚 知识库管理**:文件上传、列表查看、多选源文档,支持 PDF 等格式。 -- **💬 智能问答**:基于选中文档的上下文进行问答,对话历史本地持久化。 -- **🎨 PPT 生成**:从知识库内容或论文生成可编辑演示文稿(对接 Paper2PPT 工作流)。 -- **🧠 思维导图**:基于选中文档生成 Mermaid 思维导图,支持预览与导出。 -- **🎙️ 知识播客**:将知识库内容转为播客脚本与讲解素材。 -- **🎬 视频讲解**:生成视频脚本与讲解内容。 -- **🧩 Paper2Drawio**:从论文/文本或图片生成可编辑 DrawIO 图表,支持内嵌编辑与导出。 -- **🔍 语义检索**:基于嵌入的语义检索,支持 Top-K 与多模型选择。 +
+Notebook workspace +

Notebook workspace — Knowledge base + Smart QA + One-click generation

+
---- +
+Generation panel +

Generation panel — Multiple output formats

+
-## 📸 展示 +
+Chat and knowledge base +

Chat and knowledge base details

+
-### 首页 +
+PPT generation +

PPT generation

+
+Mind map +

Mind map

+
-首页预览 +
+DrawIO diagram +

DrawIO diagram — Inline editor

+
+
+Knowledge podcast +

Knowledge podcast

-### 二级界面(知识库与问答) +
+Flashcards +

Flashcard study

+
+Quiz +

Quiz

+
-二级界面预览 +
+Web search +

Web search to import sources

+
+
+Deep research report +

Deep research report generation

-### PPT 生成 +--- -
+## ✨ Core Features + +| Feature | Description | +|---------|-------------| +| 📚 **Knowledge Base** | Upload PDFs, paste URLs/text, import from web search — aggregate multiple sources into a notebook | +| 🔐 **User Management** | Supabase email + OTP authentication, multi-user data isolation; works without login when unconfigured | +| 💬 **Smart QA** | RAG-based Q&A grounded in selected documents, with persistent chat history | +| 🎨 **PPT Generation** | One-click editable slide decks from knowledge base content | +| 🧠 **Mind Maps** | Generate Mermaid mind maps with preview and export | +| 🎙️ **Knowledge Podcast** | Turn knowledge base content into podcast scripts and narration assets | +| 🧩 **DrawIO Diagrams** | Generate editable DrawIO diagrams from text or images, with inline editor | +| 🃏 **Flashcards** | Auto-generate study flashcards from knowledge base content | +| 📝 **Quizzes** | Auto-generate multiple-choice questions with scoring | +| 🔍 **Web Search** | Supports Serper / SerpAPI / Google CSE / Brave / Bocha search providers | +| 📊 **Deep Research Reports** | Web search + LLM synthesis to produce structured research reports | +| 🔗 **Semantic Search** | Local embedding-based vector retrieval with configurable Top-K and models | -PPT 生成 +--- -
+## 🚀 Quick Start -### 思维导图 +### 1. Clone & Install -
+```bash +git clone https://github.com/OpenDCAI/opennotebookLM.git +cd opennotebookLM -思维导图 +# Create virtual environment (Conda recommended) +conda create -n opennotebook python=3.11 -y +conda activate opennotebook -
+# Install Python dependencies +pip install -r requirements-base.txt +``` -### DrawIO 图表 +### 2. Configure API Keys -
+```bash +cp fastapi_app/.env.example fastapi_app/.env +``` -DrawIO 图表 +Edit `fastapi_app/.env` with at least the following: -
+#### LLM API (Required) ---- +The project calls LLMs via an OpenAI-compatible API. By default it uses [APIyi](https://www.apiyi.com) as a relay service (supports GPT / Claude / Gemini and more). -## 🚀 快速开始 +```env +# LLM API endpoint (OpenAI-compatible format) +DEFAULT_LLM_API_URL=https://api.apiyi.com/v1 -### 环境要求 +# Your API key (obtain from APIyi or another LLM provider) +# Can also be configured dynamically in the frontend settings panel +``` -![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white) -![Node](https://img.shields.io/badge/Node-18+-339933?style=flat-square&logo=node.js&logoColor=white) +> You can use any OpenAI-compatible API service (OpenAI official, Azure OpenAI, local Ollama, etc.) — just change `DEFAULT_LLM_API_URL`. -- **Python**: 3.10+ -- **Node.js**: 18+(前端构建) -- **操作系统**: Linux(推荐)/ Windows / macOS +#### Search API (Required for web search features) -### 后端安装与启动 +Web search and deep research report features require a search engine API. Any one of the following providers will work: -```bash -# 1. 克隆仓库 -git clone -cd opennoteboolLM +| Provider | Configuration | Sign up | +|----------|--------------|---------| +| **Serper** (recommended) | Env variable `SERPER_API_KEY` | [serper.dev](https://serper.dev) | +| **SerpAPI** | Pass `search_api_key` from frontend | [serpapi.com](https://serpapi.com) | +| **Google CSE** | Pass `search_api_key` + `google_cse_id` from frontend | [programmablesearchengine.google.com](https://programmablesearchengine.google.com) | +| **Brave Search** | Pass `search_api_key` from frontend | [brave.com/search/api](https://brave.com/search/api) | +| **Bocha** | Pass `search_api_key` from frontend | [open.bochaai.com](https://open.bochaai.com) | -# 2. 创建并激活虚拟环境(推荐 Conda) -conda create -n opennotebook python=3.11 -y -conda activate opennotebook +Serper is configured via a backend environment variable. Other providers can be set in the frontend settings panel. -# 3. 
安装依赖 -pip install -r requirements-base.txt -pip install -e . +```env +# Serper (Google search), recommended +SERPER_API_KEY=your_serper_api_key +``` -# 4. 配置环境变量(可选) -cp fastapi_app/.env.example fastapi_app/.env -# 编辑 fastapi_app/.env,配置 DF_API_KEY、DF_API_URL、Supabase 等 +#### Supabase (Optional — User Management) -# 5. 启动后端 -cd fastapi_app -uvicorn main:app --host 0.0.0.0 --port 8000 +For multi-user authentication and data isolation. **If not configured or left empty, the system automatically enters trial mode** (no login required, single local user, all core features work normally). + +When configured: email + password sign-up/login, OTP email verification, per-user data isolation (separate directories per user). + +```env +# If you don't need multi-user features, you can delete or leave empty +SUPABASE_URL=https://your-project-id.supabase.co +SUPABASE_ANON_KEY=your_supabase_anon_key +SUPABASE_SERVICE_ROLE_KEY=your_supabase_service_role_key ``` -后端健康检查:,API 文档:。 +#### TTS Voice Synthesis (Optional — Podcast Feature) + +Podcast generation supports local TTS models. When enabled, it will automatically download the [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice) model (~3.4GB). -### 前端安装与启动 +```env +# Enable local TTS (0=disabled, 1=enabled) +USE_LOCAL_TTS=1 -提供中英双前端,任选其一即可。 +# TTS engine: qwen (recommended) or firered +TTS_ENGINE=qwen + +# Model idle auto-unload timeout (seconds, default 300 = 5 minutes) +TTS_IDLE_TIMEOUT=300 +``` -**英文前端(frontend_en,NotebookLM 风格)** +> **Tip**: If you don't need podcast features, set `USE_LOCAL_TTS=0` or delete this config to save disk space. + +### 3. Start Backend ```bash -cd frontend_en -npm install -cp .env.example .env # 可选,配置 VITE_API_KEY、VITE_DEFAULT_LLM_API_URL、Supabase 等 -npm run dev +uvicorn fastapi_app.main:app --host 0.0.0.0 --port 8213 --reload ``` -**中文前端(frontend_zh)** +On startup, the backend automatically launches a local embedding service (Octen-Embedding-0.6B on port `26210` by default). 
The model is downloaded on first run. To disable local embedding, set `USE_LOCAL_EMBEDDING=0`. + +- Health check: http://localhost:8213/health +- API docs: http://localhost:8213/docs + +### 4. Start Frontend + +Both English and Chinese frontends are provided — pick either: ```bash -cd frontend_zh -npm install -npm run dev +# English frontend +cd frontend_en && npm install && npm run dev + +# Chinese frontend +cd frontend_zh && npm install && npm run dev +``` + +Open http://localhost:3000 (or the port shown in the terminal). + +> `npm run dev` uses each frontend's `vite.config.ts`, and the current default frontend port is `3000`. +> If you use the repository's `scripts/start.sh`, it starts the **Chinese frontend** on port `3001`, the backend on `8213`, and the cpolar tunnel together. + +> The LLM API URL and API key can be changed dynamically in the settings panel (top-right corner) without restarting. + +#### Frontend Configuration (Optional) + +**For local deployment** (frontend and backend on the same machine): No configuration needed. The default setup works out of the box. + +**For public deployment** (via cpolar/ngrok tunneling): + +The frontend has built-in smart detection: +- When `.env` points to `localhost` but the app is accessed from a public URL, the frontend automatically uses relative paths (current domain) +- In dev mode, Vite proxies `/api` and `/outputs` to the local backend at `http://localhost:8213` +- **Recommended**: Use an nginx reverse proxy to unify the frontend and backend under the same domain; no extra configuration is needed + +> **Note**: The ports shown here, such as `3000`, `3001`, `8080`, and `8213`, are example ports only. In a real deployment, replace them with the actual ports used by your frontend, backend, and proxy services. +> For personal testing or lightweight usage, `scripts/start.sh + Vite proxy + cpolar` is sufficient; for more stable public access or larger-scale deployments, an nginx reverse proxy is still the recommended approach.
+> In the current repository, `scripts/start.sh` uses `CPOLAR_TUNNEL_NAME=opennotebook` and prints the configured `CPOLAR_PUBLIC_URL`. If you change your reserved cpolar tunnel, update both variables in the script as well. + +Create `frontend_zh/.env` (or `frontend_en/.env`): + +```env +# Backend API base URL (for local development) +VITE_API_BASE_URL=http://localhost:8213 ``` -访问 **http://localhost:3000**(或终端提示的端口,如 3001)。 +**Deployment comparison:** + +| Deployment Type | Configuration | Description | +|----------------|---------------|-------------| +| **Local development** | `VITE_API_BASE_URL=http://localhost:8213` | Frontend and backend both run locally | +| **Using `scripts/start.sh`** | `VITE_API_BASE_URL=http://localhost:8213` | The current script starts the Chinese frontend on `3001`, backend on `8213`, and exposes the frontend through a named cpolar tunnel | +| **Public deployment (recommended)** | `VITE_API_BASE_URL=http://localhost:8213` | Use nginx reverse proxy for unified domain, smart detection auto-switches to relative paths | +| **Public deployment (separated)** | `VITE_API_BASE_URL=https://backend-xxx.cpolar.io` | Frontend and backend use different domains, requires manual backend URL configuration | + +**Recommended: Use nginx reverse proxy for unified domain** + +Create `nginx.conf`: + +```nginx +server { + listen 8080; + + # Frontend + location / { + proxy_pass http://localhost:3000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + + # Backend API + location /api/ { + proxy_pass http://localhost:8213/api/; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + + # Backend output files + location /outputs/ { + proxy_pass http://localhost:8213/outputs/; + } +} +``` -### 环境变量说明 +If you are not running `npm run dev` directly and instead use the current repository's `scripts/start.sh`, change the frontend upstream above from `http://localhost:3000` to `http://localhost:3001`. 
-- **后端 `fastapi_app/.env`** - - `DF_API_KEY`、`DF_API_URL`:LLM 调用。 - - `SUPABASE_URL`、`SUPABASE_ANON_KEY` 等:可选,用于用户认证与云存储。 -- **前端 `frontend_en/.env`** - - `VITE_API_KEY`:请求后端 API 的密钥(需与后端一致)。 - - `VITE_DEFAULT_LLM_API_URL`:默认 LLM 提供商地址。 - - `VITE_SUPABASE_*`:可选,与后端 Supabase 配置对应。 +Then expose the nginx port via cpolar: +```bash +cpolar http 8080 +``` -不配置 Supabase 时,前端可使用本地模拟用户进行开发与体验。 +This way the frontend and backend share the same domain, and smart detection automatically uses relative paths without any configuration changes. Replace the example ports with the real ports used in your environment. + +> **Note**: After changing `.env`, rebuild the frontend (`npm run build`) or restart the dev server (`npm run dev`). --- -## 📂 项目结构 +## 📂 Project Structure ``` -opennoteboolLM/ -├── dataflow_agent/ # 工作流引擎 -│ ├── agentroles/ # Agent 角色定义 -│ ├── workflow/ # 工作流(Paper2PPT、PDF2PPT、Image2Drawio、KB 等) -│ ├── promptstemplates/ # 提示模板 -│ └── toolkits/ # 工具集 -├── fastapi_app/ # 后端 API -│ ├── routers/ # 知识库、文件、Paper2Drawio、Paper2PPT 等 -│ └── workflow_adapters/ # 工作流适配 -├── frontend_en/ # 英文前端(NotebookLM 风格) -├── frontend_zh/ # 中文前端 -├── database/ # 数据库脚本 -├── docs/ # 文档 -├── script/ # CLI 与脚本 -├── static/ # 静态资源与 README 配图 -└── outputs/ # 生成文件输出目录 +opennotebookLM/ +├── fastapi_app/ # Backend API (FastAPI) +│ ├── routers/ # Routes: KB, auth, Paper2PPT, Paper2Drawio, etc. +│ ├── services/ # Business logic: search, flashcards, quizzes, etc. +│ ├── config/ # Configuration & environment variables +│ ├── dependencies/ # Dependency injection (auth, Supabase client) +│ ├── middleware/ # Middleware (API key validation) +│ └── workflow_adapters/ # Workflow adapter layer +├── workflow_engine/ # Workflow engine (DataFlow-Agent) +│ ├── agentroles/ # Agent role definitions +│ ├── workflow/ # Workflows (Paper2PPT, PDF2PPT, Image2Drawio, etc.) +│ ├── promptstemplates/ # Prompt templates +│ └── toolkits/ # Toolkits (search, parsing, etc.)
+├── frontend_en/ # English frontend (React + Vite + Tailwind) +├── frontend_zh/ # Chinese frontend +├── database/ # Database migration scripts +├── docs/ # Documentation +├── script/ # Utility scripts (DB init, etc.) +├── static/ # Static assets +└── outputs/ # Generated file output directory (isolated by user email) ``` --- -## 🤝 参与贡献 +## ⚙️ Model Configuration -欢迎提交 Issue、Pull Request 以及文档改进。 +The project uses a three-layer model configuration system, from coarse to fine-grained: -[![Issues](https://img.shields.io/badge/Issues-Submit_Bug-red?style=for-the-badge&logo=github)](https://github.com/your-org/opennoteboolLM/issues) -[![PR](https://img.shields.io/badge/PR-Submit_Code-green?style=for-the-badge&logo=github)](https://github.com/your-org/opennoteboolLM/pulls) +1. **Base model layer** — Define available model names (`MODEL_GPT_4O`, `MODEL_CLAUDE_HAIKU`, etc.) +2. **Workflow layer** — Set default models per workflow (`PAPER2PPT_DEFAULT_MODEL`, etc.) +3. **Role layer** — Fine-grained control over each role within a workflow (`PAPER2PPT_OUTLINE_MODEL`, etc.) -详见 [贡献指南](docs/contributing.md)。 +See `fastapi_app/.env.example` for the full configuration reference. --- -## 📄 许可证 +## 🗺️ Roadmap + +- [x] Knowledge base management (upload files / paste URLs / text) +- [x] RAG smart Q&A +- [x] PPT generation +- [x] Mind map generation +- [x] DrawIO diagram generation +- [x] Knowledge podcast generation +- [x] Flashcards & quizzes +- [x] Web search source import +- [x] Deep research reports +- [x] Local embedding vector retrieval +- [x] User management (Supabase email auth + multi-user isolation) +- [ ] Video generation (in progress) +- [ ] Video source import (in progress) +- [ ] Audio source import (in progress) -本项目采用 [Apache License 2.0](LICENSE)。 +--- + +## 🤝 Contributing + +Issues and pull requests are welcome. See [Contributing Guide](docs/contributing.md). 
--- -**功能卡片功能基于:[OpenDCAI/Paper2Any](https://github.com/OpenDCAI/Paper2Any)** +## 📄 License + +[Apache License 2.0](LICENSE) + +Generation features are built on [OpenDCAI/Paper2Any](https://github.com/OpenDCAI/Paper2Any). ---
-**若本项目对你有帮助,欢迎 ⭐ Star** +**If this project helps you, please give it a ⭐ Star** + +
+ +--- +## 💬 Community + +
+WeChat Group +

Scan to join our WeChat group

diff --git a/README_EN.md b/README_EN.md deleted file mode 100644 index 8838562..0000000 --- a/README_EN.md +++ /dev/null @@ -1,224 +0,0 @@ -
- -OpenNotebook Logo - -# OpenNotebookLM - -[![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white)](https://www.python.org/) -[![License](https://img.shields.io/badge/License-Apache_2.0-2F80ED?style=flat-square&logo=apache&logoColor=white)](LICENSE) - -[中文](README.md) | English - -✨ **A NotebookLM-style knowledge-base workflow: upload documents, chat with sources, and generate PPTs, mind maps, podcasts, and DrawIO diagrams in one click** ✨ - -| 📚 **Knowledge Base**  |  💬 **Smart QA**  |  🎨 **Multimodal Generation**  |  🔍 **Semantic Search** | - -
- - - Quickstart - - - Docs - - - Contributing - - -
-
- ---- - -## 📑 Table of Contents - -- [✨ Core Features](#-core-features) -- [📸 Showcase](#-showcase) -- [🚀 Quick Start](#-quick-start) -- [📂 Project Structure](#-project-structure) -- [🤝 Contributing](#-contributing) - ---- - -## ✨ Core Features - -> Built around **notebooks + knowledge base** on the DataFlow-Agent workflow engine: upload documents or papers, then use smart QA and one-click generation for multiple output types. - -- **📚 Knowledge Base**: Upload files, browse and select sources (e.g. PDFs). -- **💬 Smart QA**: Ask questions grounded in selected documents; chat history is persisted locally. -- **🎨 PPT Generation**: Generate editable slide decks from your knowledge base or papers (Paper2PPT workflow). -- **🧠 Mind Maps**: Generate Mermaid mind maps from selected sources, with preview and export. -- **🎙️ Knowledge Podcast**: Turn knowledge-base content into podcast scripts and narration assets. -- **🎬 Video Narration**: Generate video scripts and narration content. -- **🧩 Paper2Drawio**: Generate editable DrawIO diagrams from papers, text, or images; inline edit and export. -- **🔍 Semantic Search**: Embedding-based semantic retrieval with configurable Top-K and models. - ---- - -## 📸 Showcase - -### Home - -
- -Home - -
- -### Notebook View (Sources & Chat) - -
- -Notebook view - -
- -### PPT Generation - -
- -PPT generation - -
- -### Mind Map - -
- -Mind map - -
- -### DrawIO Diagrams - -
- -DrawIO - -
- ---- - -## 🚀 Quick Start - -### Requirements - -![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white) -![Node](https://img.shields.io/badge/Node-18+-339933?style=flat-square&logo=node.js&logoColor=white) - -- **Python**: 3.10+ -- **Node.js**: 18+ (for frontend build) -- **OS**: Linux (recommended) / Windows / macOS - -### Backend - -```bash -# 1. Clone -git clone -cd opennoteboolLM - -# 2. Create and activate environment (Conda recommended) -conda create -n opennotebook python=3.11 -y -conda activate opennotebook - -# 3. Install dependencies -pip install -r requirements-base.txt -pip install -e . - -# 4. Environment variables (optional) -cp fastapi_app/.env.example fastapi_app/.env -# Edit fastapi_app/.env: DF_API_KEY, DF_API_URL, Supabase, etc. - -# 5. Start backend -cd fastapi_app -uvicorn main:app --host 0.0.0.0 --port 8000 -``` - -Health: · API docs: - -### Frontend - -Both English and Chinese frontends are provided; use either. - -**English (frontend_en, NotebookLM-style)** - -```bash -cd frontend_en -npm install -cp .env.example .env # Optional: VITE_API_KEY, VITE_DEFAULT_LLM_API_URL, Supabase, etc. -npm run dev -``` - -**Chinese (frontend_zh)** - -```bash -cd frontend_zh -npm install -npm run dev -``` - -Open **http://localhost:3000** (or the port shown in the terminal, e.g. 3001). - -### Environment Variables - -- **Backend `fastapi_app/.env`** - - `DF_API_KEY`, `DF_API_URL`: LLM API. - - `SUPABASE_URL`, `SUPABASE_ANON_KEY`, etc.: optional, for auth and cloud storage. -- **Frontend `frontend_en/.env`** - - `VITE_API_KEY`: API key for backend requests (must match backend). - - `VITE_DEFAULT_LLM_API_URL`: default LLM provider URL. - - `VITE_SUPABASE_*`: optional, align with backend Supabase if used. - -Without Supabase, the frontend can use a local mock user for development and try-out. 
- ---- - -## 📂 Project Structure - -``` -opennoteboolLM/ -├── dataflow_agent/ # Workflow engine -│ ├── agentroles/ # Agent definitions -│ ├── workflow/ # Workflows (Paper2PPT, PDF2PPT, Image2Drawio, KB, etc.) -│ ├── promptstemplates/ # Prompt templates -│ └── toolkits/ # Toolkits -├── fastapi_app/ # Backend API -│ ├── routers/ # KB, files, Paper2Drawio, Paper2PPT, etc. -│ └── workflow_adapters/ # Workflow adapters -├── frontend_en/ # English frontend (NotebookLM-style) -├── frontend_zh/ # Chinese frontend -├── database/ # DB scripts -├── docs/ # Documentation -├── script/ # CLI and scripts -├── static/ # Assets and README images -└── outputs/ # Generated outputs -``` - ---- - -## 🤝 Contributing - -Issues, pull requests, and documentation improvements are welcome. - -[![Issues](https://img.shields.io/badge/Issues-Submit_Bug-red?style=for-the-badge&logo=github)](https://github.com/your-org/opennoteboolLM/issues) -[![PR](https://img.shields.io/badge/PR-Submit_Code-green?style=for-the-badge&logo=github)](https://github.com/your-org/opennoteboolLM/pulls) - -See [Contributing](docs/contributing.md). - ---- - -## 📄 License - -This project is under [Apache License 2.0](LICENSE). - ---- - -**Feature cards are based on: [OpenDCAI/Paper2Any](https://github.com/OpenDCAI/Paper2Any)** - ---- - -
- -**If this project helps you, please give it a ⭐ Star** - -
diff --git a/README_ZH.md b/README_ZH.md new file mode 100644 index 0000000..96022d1 --- /dev/null +++ b/README_ZH.md @@ -0,0 +1,386 @@ +
+ +OpenNotebookLM + +# OpenNotebookLM + +[![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=flat-square&logo=python&logoColor=white)](https://www.python.org/) +[![Node](https://img.shields.io/badge/Node-18+-339933?style=flat-square&logo=node.js&logoColor=white)](https://nodejs.org/) +[![License](https://img.shields.io/badge/License-Apache_2.0-2F80ED?style=flat-square&logo=apache&logoColor=white)](LICENSE) + +中文 | [English](README.md) + +**开源的 NotebookLM 替代方案** — 上传文档,智能问答,一键生成 PPT / 思维导图 / 播客 / DrawIO 图表 / 闪卡 / 测试题 / 深度研究报告 + +
+ +--- + +## 📅 更新日志 + +- **2026.03.11** — 代码重构:实行严格的功能分层架构;集成本地 TTS 模型([Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice));新增基于来源的笔记 QA 问答编辑功能(Notion AI 风格);UI 优化;简化配置文件结构 +- **2026.03.08** — 新增用户管理系统:Supabase 邮箱 + OTP 认证登录,多用户数据隔离,用户目录以邮箱命名;清理废弃脚本 +- **2026.02.27** — 迁移集成 [Qwen-DeepResearch](https://github.com/Alibaba-NLP/DeepResearch) 深度研究模块;PPT 生成支持 Nano Banana 2 生图模型 +- **2026.02.13** — 项目发布 + +--- + +## 📸 界面预览 + +
+首页仪表盘 +

首页仪表盘 — 笔记本管理

+
+ +
+笔记本工作区 +

笔记本工作区 — 知识库 + 智能问答 + 一键生成

+
+ +
+生成面板 +

生成面板 — 多种输出格式

+
+ +
+对话与知识库 +

对话与知识库详情

+
+ +
+PPT 生成 +

PPT 生成

+
+ +
+思维导图 +

思维导图

+
+ +
+DrawIO 图表 +

DrawIO 图表 — 内嵌编辑器

+
+ +
+知识播客 +

知识播客

+
+ +
+闪卡 +

闪卡学习

+
+ +
+测试题 +

测试题

+
+ +
+联网搜索 +

联网搜索引入来源

+
+ +
+深度研究报告 +

深度研究报告生成

+
+ +--- + +## ✨ 核心功能 + +| 功能 | 说明 | +|------|------| +| 📚 **知识库管理** | 上传 PDF 等文档、粘贴网址/文本、联网搜索引入,多源聚合到笔记本 | +| 🔐 **用户管理** | 基于 Supabase 的邮箱注册/登录 + OTP 验证,多用户数据隔离;不配置时可无登录体验全部功能 | +| 💬 **智能问答** | 基于选中文档的 RAG 问答,对话历史持久化 | +| 🎨 **PPT 生成** | 从知识库内容一键生成可编辑演示文稿 | +| 🧠 **思维导图** | 生成 Mermaid 思维导图,支持预览与导出 | +| 🎙️ **知识播客** | 将知识库内容转为播客脚本与讲解素材 | +| 🧩 **DrawIO 图表** | 从文本或图片生成可编辑 DrawIO 图表,内嵌编辑器 | +| 🃏 **闪卡** | 基于知识库内容自动生成学习闪卡 | +| 📝 **测试题** | 自动生成选择题,支持作答与评分 | +| 🔍 **联网搜索** | 支持 Serper / SerpAPI / Google CSE / Brave / 博查等多种搜索引擎 | +| 📊 **深度研究报告** | 联网搜索 + LLM 综合分析,生成结构化研究报告 | +| 🔗 **语义检索** | 本地 Embedding 向量检索,支持 Top-K 与多模型 | + +--- + +## 🚀 快速开始 + +### 1. 克隆与安装 + +```bash +git clone https://github.com/OpenDCAI/opennotebookLM.git +cd opennotebookLM + +# 创建虚拟环境(推荐 Conda) +conda create -n opennotebook python=3.11 -y +conda activate opennotebook + +# 安装 Python 依赖 +pip install -r requirements-base.txt +``` + +### 2. 配置 API 密钥 + +```bash +cp fastapi_app/.env.example fastapi_app/.env +``` + +编辑 `fastapi_app/.env`,至少配置以下内容: + +#### LLM API(必需) + +项目通过 OpenAI 兼容接口调用大模型,默认使用 [APIyi](https://www.apiyi.com) 作为中转服务(支持 GPT / Claude / Gemini 等多种模型)。 + +```env +# LLM API 地址(OpenAI 兼容格式) +DEFAULT_LLM_API_URL=https://api.apiyi.com/v1 + +# 你的 API Key(在 APIyi 或其他 LLM 提供商处获取) +# 前端设置面板中也可以动态配置 +``` + +> 也可以使用任何 OpenAI 兼容的 API 服务(如 OpenAI 官方、Azure OpenAI、本地 Ollama 等),只需修改 `DEFAULT_LLM_API_URL` 即可。 + +#### 搜索 API(联网搜索功能需要) + +联网搜索和深度研究报告功能需要配置搜索引擎 API。支持以下任一提供商: + +| 提供商 | 配置方式 | 获取地址 | +|--------|----------|----------| +| **Serper**(推荐) | 环境变量 `SERPER_API_KEY` | [serper.dev](https://serper.dev) | +| **SerpAPI** | 前端传入 `search_api_key` | [serpapi.com](https://serpapi.com) | +| **Google CSE** | 前端传入 `search_api_key` + `google_cse_id` | [programmablesearchengine.google.com](https://programmablesearchengine.google.com) | +| **Brave Search** | 前端传入 `search_api_key` | [brave.com/search/api](https://brave.com/search/api) | +| **博查** | 前端传入 `search_api_key` | [open.bochaai.com](https://open.bochaai.com) | + 
+Serper 通过后端环境变量配置,其他提供商在前端设置面板中填入对应 API Key 即可。 + +```env +# Serper(Google 搜索),推荐 +SERPER_API_KEY=your_serper_api_key +``` + +#### Supabase(可选,用户管理) + +用于多用户认证与数据隔离。**如果不配置或留空,系统将自动进入体验模式**(无需登录,单用户本地存储,所有核心功能正常使用)。 + +配置后支持:邮箱 + 密码注册登录、OTP 邮件验证、多用户数据隔离(每个用户独立目录)。 + +```env +# 如果不需要多用户功能,可以删除或留空以下配置 +SUPABASE_URL=https://your-project-id.supabase.co +SUPABASE_ANON_KEY=your_supabase_anon_key +SUPABASE_SERVICE_ROLE_KEY=your_supabase_service_role_key +``` + +#### TTS 语音合成(可选,播客功能) + +播客生成功能支持本地 TTS 模型。启用后会自动下载 [Qwen3-TTS](https://huggingface.co/Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice) 模型(约 3.4GB)。 + +```env +# 启用本地 TTS(0=禁用,1=启用) +USE_LOCAL_TTS=1 + +# TTS 引擎:qwen(推荐)或 firered +TTS_ENGINE=qwen + +# 模型空闲自动卸载时间(秒,默认 300 = 5 分钟) +TTS_IDLE_TIMEOUT=300 +``` + +> **提示**:如果不需要播客功能,可以设置 `USE_LOCAL_TTS=0` 或删除此配置以节省磁盘空间。 + +### 3. 启动后端 + +```bash +uvicorn fastapi_app.main:app --host 0.0.0.0 --port 8213 --reload +``` + +后端启动时会自动拉起本地 Embedding 服务(Octen-Embedding-0.6B,默认端口 `26210`),首次启动会下载模型。如需关闭本地 Embedding,设置 `USE_LOCAL_EMBEDDING=0`。 + +- 健康检查:http://localhost:8213/health +- API 文档:http://localhost:8213/docs + +### 4. 
启动前端 + +提供中英双前端,任选其一: + +```bash +# 中文前端 +cd frontend_zh && npm install && npm run dev + +# 英文前端 +cd frontend_en && npm install && npm run dev +``` + +访问 http://localhost:3000(或终端提示的端口)。 + +> `npm run dev` 默认读取各前端目录下的 `vite.config.ts`,当前默认端口是 `3000`。 +> 如果使用仓库自带的 `scripts/start.sh`,脚本会启动**中文前端**并强制使用 `3001` 端口,同时启动后端 `8213` 和 `cpolar` 隧道。 + +> 前端的 LLM API 地址和 API Key 可在页面右上角设置面板中动态修改,无需重启。 + +#### 前端配置(可选) + +**本地部署**(前后端在同一台机器):无需配置,默认即可使用。 + +**公网部署**(通过 cpolar/ngrok 等内网穿透工具): + +前端内置智能检测功能: +- 当 `.env` 配置为 `localhost` 但从公网访问时,会自动使用相对路径(当前域名) +- 开发模式下,Vite 会将 `/api` 和 `/outputs` 代理到本地后端 `http://localhost:8213` +- **推荐方式**:使用 nginx 反向代理,将前端和后端统一到同一域名下,无需额外配置 + +> **说明**:上面的 `3000`、`3001`、`8080`、`8213` 只是文档示例端口,实际部署时请按你的前端、后端和代理服务的真实监听端口修改对应配置。 +> 对于个人测试或轻量使用,`scripts/start.sh + Vite 代理 + cpolar` 已可工作;如需更稳定的公网访问或大规模应用,仍推荐使用 nginx 反向代理方案。 +> 当前仓库中的 `scripts/start.sh` 默认使用 `CPOLAR_TUNNEL_NAME=opennotebook`,并显示配置中的 `CPOLAR_PUBLIC_URL`。如果你修改了 cpolar 保留隧道,也请同步修改脚本里的这两个变量。 + +创建 `frontend_zh/.env`(或 `frontend_en/.env`): + +```env +# 后端 API 基础地址(本地开发) +VITE_API_BASE_URL=http://localhost:8213 +``` + +**部署方式对比:** + +| 部署方式 | 配置 | 说明 | +|---------|------|------| +| **本地开发** | `VITE_API_BASE_URL=http://localhost:8213` | 前端和后端都在本地运行 | +| **`scripts/start.sh` 启动** | `VITE_API_BASE_URL=http://localhost:8213` | 当前脚本会启动中文前端 `3001`、后端 `8213`,并通过命名 cpolar 隧道暴露前端 | +| **公网部署(推荐)** | `VITE_API_BASE_URL=http://localhost:8213` | 使用 nginx 反向代理统一域名,智能检测自动切换到相对路径 | +| **公网部署(分离)** | `VITE_API_BASE_URL=https://backend-xxx.cpolar.io` | 前后端使用不同域名,需手动配置后端地址 | + +**推荐:使用 nginx 反向代理统一域名** + +创建 `nginx.conf`: + +```nginx +server { + listen 8080; + + # 前端 + location / { + proxy_pass http://localhost:3000; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + + # 后端 API + location /api/ { + proxy_pass http://localhost:8213/api/; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + + # 后端输出文件 + location /outputs/ { + proxy_pass 
http://localhost:8213/outputs/; + } +} +``` + +如果你不是直接运行 `npm run dev`,而是沿用当前仓库的 `scripts/start.sh`,请把上面前端反向代理目标从 `http://localhost:3000` 改成 `http://localhost:3001`。 + +然后使用 cpolar 暴露 nginx 端口: +```bash +cpolar http 8080 +``` + +这样前端和后端在同一域名下,智能检测会自动使用相对路径,无需修改配置。实际部署时请把示例中的端口替换为你的真实端口。 + +> **注意**:修改 `.env` 后需要重新构建前端(`npm run build`)或重启开发服务器(`npm run dev`)。 + +--- + +## 📂 项目结构 + +``` +opennotebookLM/ +├── fastapi_app/ # 后端 API(FastAPI) +│ ├── routers/ # 路由:知识库、认证、Paper2PPT、Paper2Drawio 等 +│ ├── services/ # 业务逻辑:搜索、闪卡、测试题等 +│ ├── config/ # 配置与环境变量 +│ ├── dependencies/ # 依赖注入(认证、Supabase 客户端) +│ ├── middleware/ # 中间件(API Key 校验) +│ └── workflow_adapters/ # 工作流适配层 +├── workflow_engine/ # 工作流引擎(DataFlow-Agent) +│ ├── agentroles/ # Agent 角色定义 +│ ├── workflow/ # 工作流(Paper2PPT、PDF2PPT、Image2Drawio 等) +│ ├── promptstemplates/ # 提示模板 +│ └── toolkits/ # 工具集(搜索、解析等) +├── frontend_en/ # 英文前端(React + Vite + Tailwind) +├── frontend_zh/ # 中文前端 +├── database/ # 数据库迁移脚本 +├── docs/ # 文档 +├── script/ # 辅助脚本(数据库初始化等) +├── static/ # 静态资源 +└── outputs/ # 生成文件输出目录(按用户邮箱隔离) +``` + +--- + +## ⚙️ 模型配置 + +项目采用三层模型配置体系,灵活度从粗到细: + +1. **基础模型层** — 定义可用模型名称(`MODEL_GPT_4O`、`MODEL_CLAUDE_HAIKU` 等) +2. **工作流层** — 为每个工作流设置默认模型(`PAPER2PPT_DEFAULT_MODEL` 等) +3. **角色层** — 精细控制工作流中每个角色使用的模型(`PAPER2PPT_OUTLINE_MODEL` 等) + +详见 `fastapi_app/.env.example` 中的完整配置说明。 + +--- + +## 🗺️ Roadmap + +- [x] 知识库管理(上传文件 / 粘贴网址 / 文本) +- [x] RAG 智能问答 +- [x] PPT 生成 +- [x] 思维导图生成 +- [x] DrawIO 图表生成 +- [x] 知识播客生成 +- [x] 闪卡 & 测试题 +- [x] 联网搜索引入来源 +- [x] 深度研究报告 +- [x] 本地 Embedding 向量检索 +- [x] 用户管理(Supabase 邮箱认证 + 多用户隔离) +- [ ] 视频生成(开发中) +- [ ] 视频来源引入(开发中) +- [ ] 音频来源引入(开发中) + +--- + +## 🤝 参与贡献 + +欢迎提交 Issue 和 Pull Request。详见 [贡献指南](docs/contributing.md)。 + +--- + +## 📄 许可证 + +[Apache License 2.0](LICENSE) + +生成功能基于 [OpenDCAI/Paper2Any](https://github.com/OpenDCAI/Paper2Any)。 + +--- + +
+ +**若本项目对你有帮助,欢迎 ⭐ Star** + +
+ +--- + +## 💬 交流群 + +
+微信交流群 +

扫码加入微信交流群

+
diff --git a/database/00_cleanup.sql b/database/00_cleanup.sql deleted file mode 100644 index 6069f8c..0000000 --- a/database/00_cleanup.sql +++ /dev/null @@ -1,69 +0,0 @@ --- ============================================================================== --- Paper2Any Database Cleanup Script --- --- This script removes all existing tables, functions, triggers, views, --- and storage policies to prepare for a fresh initialization. --- --- INSTRUCTIONS: --- 1. Go to your Supabase Project Dashboard SQL Editor --- 2. Run this script FIRST to clean up existing objects --- 3. Then run 01_init_schema.sql to recreate everything --- --- WARNING: This will DELETE ALL DATA in these tables! --- ============================================================================== - --- ============================================================================== --- Step 1: Drop Storage Policies --- ============================================================================== - -DROP POLICY IF EXISTS "Authenticated users can upload files" ON storage.objects; -DROP POLICY IF EXISTS "Users can view own files" ON storage.objects; -DROP POLICY IF EXISTS "Users can delete own files" ON storage.objects; - --- ============================================================================== --- Step 2: Drop Triggers --- ============================================================================== - -DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users; - --- ============================================================================== --- Step 3: Drop Functions --- ============================================================================== - -DROP FUNCTION IF EXISTS public.handle_new_user() CASCADE; -DROP FUNCTION IF EXISTS public.apply_invite_code(TEXT) CASCADE; -DROP FUNCTION IF EXISTS public.deduct_points(UUID, INTEGER, TEXT) CASCADE; -DROP FUNCTION IF EXISTS public.check_and_grant_daily_usage(UUID) CASCADE; - --- 
============================================================================== --- Step 4: Drop Views --- ============================================================================== - -DROP VIEW IF EXISTS public.points_balance CASCADE; - --- ============================================================================== --- Step 5: Drop Tables (CASCADE will remove all dependent objects) --- ============================================================================== - -DROP TABLE IF EXISTS public.usage_records CASCADE; -DROP TABLE IF EXISTS public.user_files CASCADE; -DROP TABLE IF EXISTS public.knowledge_base_files CASCADE; -DROP TABLE IF EXISTS public.referrals CASCADE; -DROP TABLE IF EXISTS public.points_ledger CASCADE; -DROP TABLE IF EXISTS public.profiles CASCADE; - --- ============================================================================== --- Step 6: Delete Storage Objects and Bucket --- ============================================================================== - --- First, delete all objects in the bucket -DELETE FROM storage.objects WHERE bucket_id = 'user-files'; - --- Then, delete the bucket itself -DELETE FROM storage.buckets WHERE id = 'user-files'; - --- ============================================================================== --- Cleanup Complete! --- Now you can run 01_init_schema.sql to recreate everything fresh. --- ============================================================================== - diff --git a/database/01_init_schema.sql b/database/01_init_schema.sql deleted file mode 100644 index 451c5c9..0000000 --- a/database/01_init_schema.sql +++ /dev/null @@ -1,459 +0,0 @@ --- ============================================================================== --- Paper2Any Supabase Schema Setup Script --- --- This script sets up the necessary tables, views, functions, triggers, --- storage buckets, and security policies for the Paper2Any application. 
--- --- INCLUDES: --- - User management (profiles, referrals, points system) --- - File storage (user_files, knowledge_base_files) --- - Usage tracking and quota management --- - Storage buckets and RLS policies --- --- INSTRUCTIONS: --- 1. Go to your Supabase Project Dashboard: https://supabase.com/dashboard --- 2. Navigate to the "SQL Editor" section. --- 3. Click "New query", paste this entire script, and click "Run". --- --- Last Updated: 2026-01-26 (merged with knowledge base schema) --- ============================================================================== - --- ============================================================================== --- Schema Permissions --- Grant necessary permissions to authenticated users to access public schema --- ============================================================================== - --- CRITICAL: Grant USAGE permission on public schema --- Without this, authenticated users cannot access any tables, views, or functions -GRANT USAGE ON SCHEMA public TO authenticated; - --- Grant ALL privileges on public schema (recommended for Supabase) -GRANT ALL ON SCHEMA public TO authenticated; - --- ============================================================================== --- Table: usage_records --- Tracks API/Workflow usage for quota management. 
--- ============================================================================== -DROP POLICY IF EXISTS "Authenticated users can upload files" ON storage.objects; -DROP POLICY IF EXISTS "Users can view own files" ON storage.objects; -DROP POLICY IF EXISTS "Users can delete own files" ON storage.objects; - -CREATE TABLE IF NOT EXISTS public.usage_records ( - id UUID DEFAULT gen_random_uuid() PRIMARY KEY, - user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, - workflow_type TEXT NOT NULL, - called_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.usage_records ENABLE ROW LEVEL SECURITY; - --- Policy: Allow users to insert their own usage records -CREATE POLICY "Allow creation of usage records" -ON public.usage_records -FOR INSERT -WITH CHECK (true); - --- Policy: Allow users to view their own usage records -CREATE POLICY "Allow users to view their own usage" -ON public.usage_records -FOR SELECT -USING (auth.uid() = user_id); - --- ============================================================================== --- Table: user_files --- Stores metadata for generated files. 
--- ============================================================================== - -CREATE TABLE IF NOT EXISTS public.user_files ( - id UUID DEFAULT gen_random_uuid() PRIMARY KEY, - user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, - file_name TEXT NOT NULL, - file_size BIGINT, - workflow_type TEXT, - file_path TEXT, - created_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.user_files ENABLE ROW LEVEL SECURITY; - --- Add index for performance -CREATE INDEX IF NOT EXISTS idx_user_files_user_id ON public.user_files(user_id); - --- Policy: Users can only see their own files -CREATE POLICY "Users can view own files" -ON public.user_files -FOR SELECT -USING (auth.uid() = user_id); - --- Policy: Users can insert their own files -CREATE POLICY "Users can upload own files" -ON public.user_files -FOR INSERT -WITH CHECK (auth.uid() = user_id); - --- Policy: Users can delete their own files -CREATE POLICY "Users can delete own files" -ON public.user_files -FOR DELETE -USING (auth.uid() = user_id); - --- ============================================================================== --- Table: knowledge_base_files --- Stores metadata for knowledge base files (PDFs, videos, documents, etc.) 
--- ============================================================================== - -CREATE TABLE IF NOT EXISTS public.knowledge_base_files ( - id UUID DEFAULT gen_random_uuid() PRIMARY KEY, - user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, - user_email TEXT, - file_name TEXT NOT NULL, - file_type TEXT, - file_size BIGINT, - storage_path TEXT NOT NULL, - is_embedded BOOLEAN DEFAULT FALSE, - kb_file_id TEXT, - description TEXT, - created_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.knowledge_base_files ENABLE ROW LEVEL SECURITY; - --- Add index for performance -CREATE INDEX IF NOT EXISTS idx_kb_files_user_id ON public.knowledge_base_files(user_id); - --- Policy: Users can only see their own KB files -CREATE POLICY "Users can view own KB files" -ON public.knowledge_base_files -FOR SELECT -USING (auth.uid() = user_id); - --- Policy: Users can insert their own KB files -CREATE POLICY "Users can insert own KB files" -ON public.knowledge_base_files -FOR INSERT -WITH CHECK (auth.uid() = user_id); - --- Policy: Users can delete their own KB files -CREATE POLICY "Users can delete own KB files" -ON public.knowledge_base_files -FOR DELETE -USING (auth.uid() = user_id); - --- Grant table privileges to authenticated role (required in addition to RLS) -GRANT SELECT, INSERT, UPDATE, DELETE ON public.knowledge_base_files TO authenticated; - --- ============================================================================== --- Table: profiles --- Stores user profiles with invite codes. 
--- ============================================================================== - -CREATE TABLE IF NOT EXISTS public.profiles ( - user_id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE, - invite_code TEXT UNIQUE NOT NULL DEFAULT upper(substr(md5(random()::text), 1, 8)), - created_at TIMESTAMPTZ DEFAULT NOW(), - updated_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY; - --- Policy: Users can view own profile -CREATE POLICY "Users can view own profile" -ON public.profiles -FOR SELECT -USING (auth.uid() = user_id); - --- Grant SELECT permission to authenticated users -GRANT SELECT ON public.profiles TO authenticated; - --- ============================================================================== --- Table: referrals --- Tracks who invited whom. --- ============================================================================== - -CREATE TABLE IF NOT EXISTS public.referrals ( - id BIGSERIAL PRIMARY KEY, - inviter_user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, - invitee_user_id UUID NOT NULL UNIQUE REFERENCES auth.users(id) ON DELETE CASCADE, - invite_code TEXT NOT NULL, - created_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.referrals ENABLE ROW LEVEL SECURITY; - --- Policy: Users can view own referrals -CREATE POLICY "Users can view own referrals" -ON public.referrals -FOR SELECT -USING (auth.uid() = inviter_user_id OR auth.uid() = invitee_user_id); - --- Grant SELECT permission to authenticated users -GRANT SELECT ON public.referrals TO authenticated; - --- ============================================================================== --- Table: points_ledger --- Records all points (usage count) transactions. 
--- ============================================================================== - -CREATE TABLE IF NOT EXISTS public.points_ledger ( - id BIGSERIAL PRIMARY KEY, - user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE, - points INTEGER NOT NULL, - reason TEXT NOT NULL, - event_key TEXT UNIQUE, - created_at TIMESTAMPTZ DEFAULT NOW() -); - --- Enable Row Level Security -ALTER TABLE public.points_ledger ENABLE ROW LEVEL SECURITY; - --- Policy: Users can view own points -CREATE POLICY "Users can view own points" -ON public.points_ledger -FOR SELECT -USING (auth.uid() = user_id); - --- Grant SELECT permission to authenticated users -GRANT SELECT ON public.points_ledger TO authenticated; - --- ============================================================================== --- View: points_balance --- Calculates current balance per user. --- ============================================================================== - -CREATE OR REPLACE VIEW public.points_balance AS -SELECT - user_id, - COALESCE(SUM(points), 0)::INTEGER AS balance -FROM public.points_ledger -GROUP BY user_id; - --- Grant SELECT permission on points_balance view to authenticated users -GRANT SELECT ON public.points_balance TO authenticated; - --- ============================================================================== --- Function: handle_new_user --- Trigger function to create profile and award signup bonus. 
--- ============================================================================== - -CREATE OR REPLACE FUNCTION public.handle_new_user() -RETURNS TRIGGER AS $$ -BEGIN - INSERT INTO public.profiles (user_id) - VALUES (NEW.id) - ON CONFLICT (user_id) DO NOTHING; - - -- Award signup bonus: 20 usage counts - INSERT INTO public.points_ledger (user_id, points, reason, event_key) - VALUES (NEW.id, 20, 'signup_bonus', 'signup_bonus_' || NEW.id::text) - ON CONFLICT (event_key) DO NOTHING; - - RETURN NEW; -END; -$$ LANGUAGE plpgsql SECURITY DEFINER; - --- Trigger: on_auth_user_created -DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users; -CREATE TRIGGER on_auth_user_created - AFTER INSERT ON auth.users - FOR EACH ROW EXECUTE FUNCTION public.handle_new_user(); - --- ============================================================================== --- Function: apply_invite_code --- Claims invite code and awards points to both parties. --- ============================================================================== - -CREATE OR REPLACE FUNCTION public.apply_invite_code(p_code TEXT) -RETURNS JSON AS $$ -DECLARE - v_inviter_id UUID; - v_invitee_id UUID := auth.uid(); - v_existing_referral BIGINT; - v_inviter_points INTEGER := 10; - v_invitee_points INTEGER := 10; -BEGIN - -- Check if user is logged in - IF v_invitee_id IS NULL THEN - RETURN json_build_object('success', false, 'error', 'not_authenticated'); - END IF; - - -- Check if already claimed an invite code - SELECT id INTO v_existing_referral - FROM public.referrals - WHERE invitee_user_id = v_invitee_id; - - IF v_existing_referral IS NOT NULL THEN - RETURN json_build_object('success', false, 'error', 'already_claimed'); - END IF; - - -- Find inviter by invite code - SELECT user_id INTO v_inviter_id - FROM public.profiles - WHERE invite_code = UPPER(p_code); - - IF v_inviter_id IS NULL THEN - RETURN json_build_object('success', false, 'error', 'invalid_code'); - END IF; - - -- Cannot invite yourself - IF v_inviter_id 
= v_invitee_id THEN - RETURN json_build_object('success', false, 'error', 'self_invite'); - END IF; - - -- Create referral record - INSERT INTO public.referrals (inviter_user_id, invitee_user_id, invite_code) - VALUES (v_inviter_id, v_invitee_id, UPPER(p_code)); - - -- Award points to inviter - INSERT INTO public.points_ledger (user_id, points, reason, event_key) - VALUES (v_inviter_id, v_inviter_points, 'referral_inviter', - 'referral_inviter_' || v_inviter_id::text || '_' || v_invitee_id::text); - - -- Award points to invitee - INSERT INTO public.points_ledger (user_id, points, reason, event_key) - VALUES (v_invitee_id, v_invitee_points, 'referral_invitee', - 'referral_invitee_' || v_invitee_id::text); - - RETURN json_build_object('success', true, 'inviter_id', v_inviter_id); -END; -$$ LANGUAGE plpgsql SECURITY DEFINER; - -GRANT EXECUTE ON FUNCTION public.apply_invite_code(TEXT) TO authenticated; - --- ============================================================================== --- Function: deduct_points --- Deducts points from user balance. 
--- ============================================================================== - -CREATE OR REPLACE FUNCTION public.deduct_points( - p_user_id UUID, - p_amount INTEGER, - p_reason TEXT -) RETURNS BOOLEAN AS $$ -DECLARE - v_current_balance INTEGER; - v_event_key TEXT; -BEGIN - -- Get current balance - SELECT balance INTO v_current_balance - FROM public.points_balance - WHERE user_id = p_user_id; - - -- If no balance record exists, user has 0 points - IF v_current_balance IS NULL THEN - v_current_balance := 0; - END IF; - - -- Check if user has enough points - IF v_current_balance < p_amount THEN - RETURN FALSE; - END IF; - - -- Generate unique event_key using timestamp - v_event_key := p_reason || '_' || p_user_id::text || '_' || extract(epoch from now())::text; - - -- Deduct points by inserting negative ledger entry - INSERT INTO public.points_ledger (user_id, points, reason, event_key) - VALUES (p_user_id, -p_amount, p_reason, v_event_key); - - RETURN TRUE; -END; -$$ LANGUAGE plpgsql SECURITY DEFINER; - -GRANT EXECUTE ON FUNCTION public.deduct_points(UUID, INTEGER, TEXT) TO authenticated; - --- ============================================================================== --- Function: check_and_grant_daily_usage --- Grants 10 daily usage counts if user balance <= 30. 
--- ============================================================================== - -CREATE OR REPLACE FUNCTION public.check_and_grant_daily_usage(p_user_id UUID) -RETURNS INTEGER AS $$ -DECLARE - v_balance INTEGER; - v_event_key TEXT; -BEGIN - -- Get current balance from view - SELECT balance INTO v_balance - FROM public.points_balance - WHERE user_id = p_user_id; - - -- If no balance record exists, user has 0 points - IF v_balance IS NULL THEN - v_balance := 0; - END IF; - - -- Check if balance > 30, no daily grant - IF v_balance > 30 THEN - RETURN v_balance; - END IF; - - -- Generate event_key for today's grant (idempotency) - v_event_key := 'daily_grant_' || CURRENT_DATE::text || '_' || p_user_id::text; - - -- Grant 10 usage counts (idempotent insert using event_key) - INSERT INTO public.points_ledger (user_id, points, reason, event_key) - VALUES (p_user_id, 10, 'daily_grant', v_event_key) - ON CONFLICT (event_key) DO NOTHING; - - -- Return new balance (recalculate from view) - SELECT balance INTO v_balance - FROM public.points_balance - WHERE user_id = p_user_id; - - RETURN COALESCE(v_balance, 0); -END; -$$ LANGUAGE plpgsql SECURITY DEFINER; - -GRANT EXECUTE ON FUNCTION public.check_and_grant_daily_usage(UUID) TO authenticated; - -COMMENT ON FUNCTION public.check_and_grant_daily_usage IS -'Grants 10 daily usage counts if user balance <= 30. Idempotent - safe to call multiple times per day.'; - --- ============================================================================== --- Storage Bucket: user-files --- Stores the actual binary files (PDFs, PPTs, Images). 
--- ============================================================================== - --- Create the bucket if it doesn't exist -INSERT INTO storage.buckets (id, name, public) -VALUES ('user-files', 'user-files', true) -ON CONFLICT (id) DO NOTHING; - --- Policy: Allow authenticated users to upload files to their own folder -CREATE POLICY "Authenticated users can upload files" -ON storage.objects -FOR INSERT -TO authenticated -WITH CHECK ( - bucket_id = 'user-files' AND - (storage.foldername(name))[1] = auth.uid()::text -); - --- Policy: Users can view/download their own files -CREATE POLICY "Users can view own files" -ON storage.objects -FOR SELECT -TO authenticated -USING ( - bucket_id = 'user-files' AND - (storage.foldername(name))[1] = auth.uid()::text -); - --- Policy: Users can delete their own files -CREATE POLICY "Users can delete own files" -ON storage.objects -FOR DELETE -TO authenticated -USING ( - bucket_id = 'user-files' AND - (storage.foldername(name))[1] = auth.uid()::text -); - --- ============================================================================== --- Done! --- ============================================================================== diff --git a/dataflow_agent/agentroles/common_agents/imagetextbboxagent_agent.py b/dataflow_agent/agentroles/common_agents/imagetextbboxagent_agent.py deleted file mode 100644 index c970fa3..0000000 --- a/dataflow_agent/agentroles/common_agents/imagetextbboxagent_agent.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -Imagetextbboxagent agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2026-01-12 19:13:38 -生成位置: dataflow_agent/agentroles/common_agents/imagetextbboxagent_agent.py - -本文件由 `dfa create --agent_name ImageTextBBoxAgent` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("ImageTextBBoxAgent") -class Imagetextbboxagent(BaseAgent): - """TODO: 描述 ImageTextBBoxAgent 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "ImageTextBBoxAgent" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_image_text_bbox_agent" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_image_text_bbox_agent" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return {} - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - state.bbox_result = result - - 
super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def imagetextbboxagent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """ImageTextBBoxAgent 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = Imagetextbboxagent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_imagetextbboxagent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - 
use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> Imagetextbboxagent: - return Imagetextbboxagent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/common_agents/test_graph_agent.py b/dataflow_agent/agentroles/common_agents/test_graph_agent.py deleted file mode 100644 index adadb8a..0000000 --- a/dataflow_agent/agentroles/common_agents/test_graph_agent.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -TestGraph agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-01 20:16:36 -生成位置: dataflow_agent/agentroles/common_agents/test_graph_agent.py - -本文件由 `dfa create --agent_name test_graph` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("test_graph") -class TestGraph(BaseAgent): - """TODO: 描述 test_graph 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "test_graph" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_test_graph" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_test_graph" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return { - "purpose": pre_tool_results.get("purpose", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - # state. 
= result - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def test_graph( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """test_graph 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = TestGraph( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_test_graph( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - 
vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> TestGraph: - return TestGraph.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/__init__.py b/dataflow_agent/agentroles/paper2any_agents/__init__.py deleted file mode 100644 index e7ae9cd..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# dataflow_agent/agentroles/paper2any_agents/__init__.py - -from .paper_idea_extractor import PaperIdeaExtractor, create_paper_idea_extractor -from .chart_type_recommender import ChartTypeRecommender, create_chart_type_recommender -from .chart_code_generator import ChartCodeGenerator, create_chart_code_generator -from .fig_desc_generator import FigureDescGenerator -from .deep_research_agent import DeepResearchAgent, create_deep_research_agent - -__all__ = [ - "PaperIdeaExtractor", - "create_paper_idea_extractor", - "ChartTypeRecommender", - "create_chart_type_recommender", - "ChartCodeGenerator", - "create_chart_code_generator", - "FigureDescGenerator", - "DeepResearchAgent", - "create_deep_research_agent", -] diff --git a/dataflow_agent/agentroles/paper2any_agents/chart_code_generator.py b/dataflow_agent/agentroles/paper2any_agents/chart_code_generator.py deleted file mode 100644 index e3d12cb..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/chart_code_generator.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -ChartCodeGenerator agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -根据图表配置和表格数据生成 matplotlib 代码 - -用于 Paper2ExpFigure 工作流 -""" - -from __future__ import annotations - -import json -from typing import Any, Dict, List, Optional - -from 
dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -log = get_logger(__name__) - - -@register("chart_code_generator") -class ChartCodeGenerator(BaseAgent): - """根据图表配置生成 matplotlib 代码的 Agent""" - - @property - def role_name(self) -> str: - return "chart_code_generator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_chart_code_generator" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_chart_code_generator" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """从 pre_tool_results 中获取 prompt 参数""" - paper_idea = pre_tool_results.get("paper_idea", "") - chart_config = pre_tool_results.get("chart_config", {}) - table_caption = pre_tool_results.get("table_caption", "") - # 将配置格式化为 JSON 字符串 - if isinstance(chart_config, dict): - chart_config_str = ( - f"你要生成的图表类型:{chart_config.get('chart_type', '')}\n" - f"为什么要选择这种类型:{chart_config.get('chart_type_reason', '')}\n" - f"对图表预期的描述:{chart_config.get('chart_desc', '')}\n" - ) - else: - chart_config_str = str(chart_config) - - return { - "paper_idea": paper_idea, - "chart_config": chart_config_str, - "table_caption": table_caption, - - # "table_headers": headers_str, - # "table_rows": rows_str, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """默认的 pre_tool_results""" - return { - "paper_idea": "", - "chart_config": {}, - "table_caption": "", - # "table_headers": [], - # "table_rows": [], - } - - def update_state_result( - self, - state: Paper2FigureState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将生成的代码写入 state.generated_codes""" - try: - if isinstance(result, dict): - code = result.get("code", "") - description = 
result.get("description", "") - - if code: - # 获取当前正在处理的 table_id - # 优先使用state里面传入的,因为vlm模式下,base_agent更新状态时不给提供pre_tool_results - if state.pre_tool_results.get("chart_config", {}): - chart_config = state.pre_tool_results.get("chart_config", {}) - else: - chart_config = pre_tool_results.get("chart_config", {}) - table_id = chart_config.get("table_id", f"table_{len(state.generated_codes)}") - - code_entry = { - "table_id": table_id, - "code": code, - "description": description, - } - state.generated_codes[table_id] = code_entry - log.info(f"[ChartCodeGenerator] 生成代码: {table_id}, 长度: {len(code)}") - except Exception as e: - log.warning(f"[ChartCodeGenerator] 更新 state 失败: {e}") - - return super().update_state_result(state, result, pre_tool_results) - - - async def execute_pre_tools(self, state: MainState) -> Dict[str, Any]: - """重写 execute_pre_tools,方便并行调用时注入前置工具结果""" - results = await super().execute_pre_tools(state) - - inject_results = state.pre_tool_results - for key, value in inject_results.items(): - if value: - results.update( - {key: value} - ) - - return results - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def chart_code_generator( - state: Paper2FigureState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> Paper2FigureState: - """ChartCodeGenerator 的异步入口""" - inst = create_chart_code_generator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - vlm_config=vlm_config, - ) - return await inst.execute(state, use_agent=use_agent, **kwargs) - - -def create_chart_code_generator( - tool_manager: Optional[ToolManager] = None, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> ChartCodeGenerator: 
- """创建 ChartCodeGenerator 实例""" - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - return ChartCodeGenerator( - tool_manager=tool_manager, - vlm_config=vlm_config, - use_vlm=True, - **kwargs - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/chart_type_recommender.py b/dataflow_agent/agentroles/paper2any_agents/chart_type_recommender.py deleted file mode 100644 index 0fa0596..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/chart_type_recommender.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -ChartTypeRecommender agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -根据论文核心思想和表格数据推荐合适的图表类型 - -用于 Paper2ExpFigure 工作流 -""" - -from __future__ import annotations - -import json -from typing import Any, Dict, List, Optional - -from dataflow_agent.agentroles.cores.strategies import VLMStrategy -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -log = get_logger(__name__) - - -@register("chart_type_recommender") -class ChartTypeRecommender(BaseAgent): - """根据论文思想和表格数据推荐图表类型的 Agent""" - - @property - def role_name(self) -> str: - return "chart_type_recommender" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_chart_type_recommender" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_chart_type_recommender" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """从 pre_tool_results 中获取 prompt 参数""" - print("ChartTypeRecommender get_task_prompt_params", pre_tool_results) - paper_idea = pre_tool_results.get("paper_idea", "") - table_info = pre_tool_results.get("table_info", {}) - log.info(f"[ChartTypeRecommender] 
获取 prompt 参数: paper_idea={paper_idea}, table_info={table_info}") - - # 格式化表格信息为字符串 - if isinstance(table_info, dict): - table_info_str = json.dumps(table_info, ensure_ascii=False, indent=2) - else: - table_info_str = str(table_info) - - return { - "paper_idea": paper_idea, - "table_info": table_info_str, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """默认的 pre_tool_results""" - return { - "paper_idea": "", - "table_info": {}, - } - - def update_state_result( - self, - state: Paper2FigureState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推荐结果写入 state.chart_configs""" - try: - if isinstance(result, dict): - # 检查表格是否适合绘图 - is_suitable = result.get("is_suitable_for_chart", True) - suitability_reason = result.get("suitability_reason", "") - - chart_type = result.get("chart_type", "bar") - chart_type_reason = result.get("chart_type_reason", "") - chart_desc = result.get("chart_desc", "") - - # 获取当前正在处理的 table_id - # 优先使用state里面传入的,因为vlm模式下,base_agent更新状态时不给提供pre_tool_results - if state.pre_tool_results.get("table_info", {}): - table_info = state.pre_tool_results.get("table_info", {}) - else: - table_info = pre_tool_results.get("table_info", {}) - table_id = table_info.get("table_id", f"table_{len(state.chart_configs)}") - - # 如果表格不适合绘图,记录原因并跳过 - if not is_suitable or chart_type == "none": - log.info( - f"[ChartTypeRecommender] 表格 {table_id} 不适合绘图: {suitability_reason}" - ) - # 仍然添加配置,但标记为 chart_type="none" - chart_config = { - "table_id": table_id, - "is_suitable_for_chart": False, - "suitability_reason": suitability_reason, - "chart_type": "none", - "chart_type_reason": chart_type_reason, - } - else: - # 构建完整的图表配置 - chart_config = { - "table_id": table_id, - "is_suitable_for_chart": True, - "suitability_reason": suitability_reason, - "chart_type": chart_type, - "chart_type_reason": chart_type_reason, - "chart_desc": chart_desc, - } - log.info(f"[ChartTypeRecommender] 推荐图表类型: {table_id} -> {chart_type}") - - 
state.chart_configs[table_id] = chart_config - except Exception as e: - log.warning(f"[ChartTypeRecommender] 更新 state 失败: {e}") - - return super().update_state_result(state, result, pre_tool_results) - - async def execute_pre_tools(self, state: MainState) -> Dict[str, Any]: - """重写 execute_pre_tools,方便并行调用时注入前置工具结果""" - results = await super().execute_pre_tools(state) - - inject_results = state.pre_tool_results - for key, value in inject_results.items(): - if value: - results.update( - {key: value} - ) - - return results - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def chart_type_recommender( - state: Paper2FigureState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 2048, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> Paper2FigureState: - """ChartTypeRecommender 的异步入口""" - inst = create_chart_type_recommender( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - vlm_config=vlm_config, - ) - return await inst.execute(state, use_agent=use_agent, **kwargs) - - -def create_chart_type_recommender( - tool_manager: Optional[ToolManager] = None, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs -) -> ChartTypeRecommender: - """创建 ChartTypeRecommender 实例""" - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - - return ChartTypeRecommender( - tool_manager=tool_manager, - vlm_config=vlm_config, - use_vlm=True, - **kwargs - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/content_expander_agent.py b/dataflow_agent/agentroles/paper2any_agents/content_expander_agent.py deleted file mode 100644 index 45c4c4c..0000000 --- 
a/dataflow_agent/agentroles/paper2any_agents/content_expander_agent.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -ContentExpander agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Description: 负责对输入文本进行扩写,使其达到足够的长度。 -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("content_expander") -class ContentExpander(BaseAgent): - """ - ContentExpander: 接收文本,进行迭代扩写。 - """ - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: - return "content_expander" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_content_expander" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_content_expander" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """ - 构造 Prompt 参数。 - 需要 Workflow 传入: - - text_content: 待扩写的文本 - - expansion_round: 当前扩写轮次 - """ - return { - "text_content": self.state.text_content, - "expansion_round": 0, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "text_content": "", - "expansion_round": 0 - } - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """ - 
将扩写后的文本(字符串)写回 State。 - """ - state.text_content = result - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def content_expander( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "text", # 默认返回文本 - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - agent = ContentExpander( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_content_expander( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "text", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> ContentExpander: - return ContentExpander.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git 
a/dataflow_agent/agentroles/paper2any_agents/deep_research_agent.py b/dataflow_agent/agentroles/paper2any_agents/deep_research_agent.py deleted file mode 100644 index 567739f..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/deep_research_agent.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -DeepResearchAgent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -用于根据 Topic 生成详细的长篇研究报告/大纲内容。 -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("deep_research_agent") -class DeepResearchAgent(BaseAgent): - """ - DeepResearchAgent: 接收 Topic,输出长篇研究报告。 - """ - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: - return "deep_research_agent" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_deep_research_agent" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_deep_research_agent" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "text_content": self.state.text_content or "", - "language": getattr(self.state.request, "language", "zh"), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - 
state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """ - Deep Research 的结果通常是长文本。 - 如果是 parser_type="text",result 就是 str。 - 如果是 parser_type="json",result 是 dict。 - 这里假设我们使用 text parser,或者 json 中有个 'content' 字段。 - """ - # 假设 LLM 直接返回文本内容,或者 JSON 中包含 content - content = "" - if isinstance(result, str): - content = result - elif isinstance(result, dict): - # 尝试获取常见字段 - content = result.get("content") or result.get("report") or result.get("research_result") or str(result) - - # 将生成的长文本写回 state.text_content,供后续 outline_agent 使用 - if content: - state.text_content = content - log.info(f"[deep_research_agent]: Generated content length: {len(content)}") - - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -def create_deep_research_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 8192, - parser_type: str = "text", # 默认为 text,直接输出长文 - **kwargs, -) -> DeepResearchAgent: - return DeepResearchAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - parser_type=parser_type, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/diagram_editor.py b/dataflow_agent/agentroles/paper2any_agents/diagram_editor.py deleted file mode 100644 index aca7de5..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/diagram_editor.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -DiagramEditor agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Edit existing draw.io XML based on user instructions. 
-""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Tuple - -from dataflow_agent.state import Paper2DrawioState -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent, ValidatorFunc -from dataflow_agent.toolkits.drawio_tools import validate_xml, sanitize_cells_xml - -log = get_logger(__name__) - - -@register("diagram_editor") -class DiagramEditor(BaseAgent): - """Edit existing draw.io XML.""" - - def __init__(self, **kwargs): - kwargs["parser_type"] = "text" - super().__init__(**kwargs) - - @property - def role_name(self) -> str: - return "diagram_editor" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_diagram_editor" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_diagram_editor" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "current_xml": pre_tool_results.get("current_xml", ""), - "edit_instruction": pre_tool_results.get("edit_instruction", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "current_xml": "", - "edit_instruction": "", - } - - def get_react_validators(self) -> List[ValidatorFunc]: - return [ - self._validator_has_mxcell, - self._validator_no_markdown, - self._validator_no_xml_comments, - self._validator_xml_valid, - ] - - @staticmethod - def _extract_xml_text(content: str, parsed_result: Dict[str, Any]) -> str: - if isinstance(parsed_result, dict): - return ( - parsed_result.get("text", "") - or parsed_result.get("xml", "") - or parsed_result.get("drawio_xml", "") - or "" - ) - return content or "" - - @classmethod - def _validator_has_mxcell( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result).strip() - if not xml_text or " 
Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result) - if "```" in xml_text or "```" in content: - return False, "不要输出 markdown 代码块标记(```),仅输出 mxCell XML。" - return True, None - - @classmethod - def _validator_no_xml_comments( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result) - if "),仅输出 mxCell XML。" - return True, None - - @classmethod - def _validator_xml_valid( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result).strip() - if not xml_text: - return False, "输出不能为空,请返回 mxCell XML。" - is_valid, errors = validate_xml(xml_text) - if not is_valid: - hint = "请确保所有 & < > 等特殊字符已转义(例如 &),并只输出 mxCell 元素。" - return False, f"XML 解析失败: {'; '.join(errors)}。{hint}" - return True, None - - def update_state_result( - self, - state: Paper2DrawioState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - try: - if isinstance(result, dict): - xml_content = result.get("text", "") or result.get("xml", "") or result.get("drawio_xml", "") - elif isinstance(result, str): - xml_content = result - else: - xml_content = str(result) - - if xml_content: - xml_content = xml_content.strip() - if xml_content.startswith("```xml"): - xml_content = xml_content[6:] - elif xml_content.startswith("```"): - xml_content = xml_content[3:] - if xml_content.endswith("```"): - xml_content = xml_content[:-3] - xml_content = sanitize_cells_xml(xml_content) - xml_content = xml_content.strip() - - if xml_content: - state.drawio_xml = xml_content - state.drawio_xml_history.append(xml_content) - log.info(f"[DiagramEditor] XML updated, length: {len(xml_content)}") - else: - log.warning("[DiagramEditor] No valid XML produced") - except Exception as e: - log.error(f"[DiagramEditor] Failed to update state: {e}") diff --git 
a/dataflow_agent/agentroles/paper2any_agents/diagram_planner.py b/dataflow_agent/agentroles/paper2any_agents/diagram_planner.py deleted file mode 100644 index 0a60282..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/diagram_planner.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -DiagramPlanner agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -分析内容并规划图表结构 - -用于 Paper2Drawio 工作流 -""" - -from __future__ import annotations - -from typing import Any, Dict - -from dataflow_agent.state import Paper2DrawioState -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -log = get_logger(__name__) - - -@register("diagram_planner") -class DiagramPlanner(BaseAgent): - """规划图表结构的 Agent""" - - @property - def role_name(self) -> str: - return "diagram_planner" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_diagram_planner" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_diagram_planner" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """从 pre_tool_results 中获取 prompt 参数""" - paper_content = pre_tool_results.get("paper_content", "") - text_content = pre_tool_results.get("text_content", "") - diagram_type = pre_tool_results.get("diagram_type", "auto") - language = pre_tool_results.get("language", "") - - return { - "paper_content": paper_content, - "text_content": text_content, - "diagram_type": diagram_type, - "language": language, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """默认的 pre_tool_results""" - return { - "paper_content": "", - "text_content": "", - "diagram_type": "auto", - "language": "", - } - - def update_state_result( - self, - state: Paper2DrawioState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将规划结果写入 state""" - try: - if isinstance(result, 
dict): - plan = result.get("diagram_plan", "") or result.get("plan", "") - elif isinstance(result, str): - plan = result - else: - plan = str(result) - - if plan: - state.diagram_plan = plan - log.info(f"[DiagramPlanner] 规划完成,长度: {len(plan)}") - else: - log.warning("[DiagramPlanner] 未生成有效的规划") - - except Exception as e: - log.error(f"[DiagramPlanner] 更新状态失败: {e}") diff --git a/dataflow_agent/agentroles/paper2any_agents/diagram_vlm_validator.py b/dataflow_agent/agentroles/paper2any_agents/diagram_vlm_validator.py deleted file mode 100644 index 1cb8afc..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/diagram_vlm_validator.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -DiagramVlmValidator agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -VLM-based diagram validation for draw.io outputs. -""" - -from __future__ import annotations - -from typing import Any, Dict - -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - - -@register("diagram_vlm_validator") -class DiagramVlmValidator(BaseAgent): - """Validate rendered diagram image with VLM.""" - - @property - def role_name(self) -> str: - return "diagram_vlm_validator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_drawio_vlm_validator" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_drawio_vlm_validator" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "diagram_type": pre_tool_results.get("diagram_type", "auto"), - "diagram_xml": pre_tool_results.get("diagram_xml", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "diagram_type": "auto", - "diagram_xml": "", - } diff --git a/dataflow_agent/agentroles/paper2any_agents/drawio_xml_generator.py b/dataflow_agent/agentroles/paper2any_agents/drawio_xml_generator.py deleted file mode 100644 index 
86b2fe7..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/drawio_xml_generator.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -DrawioXmlGenerator agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成 draw.io XML 格式的图表 - -用于 Paper2Drawio 工作流 -""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Tuple - -from dataflow_agent.state import Paper2DrawioState -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent, ValidatorFunc -from dataflow_agent.toolkits.drawio_tools import validate_xml, sanitize_cells_xml - -log = get_logger(__name__) - - -@register("drawio_xml_generator") -class DrawioXmlGenerator(BaseAgent): - """生成 draw.io XML 的 Agent""" - - def __init__(self, **kwargs): - kwargs['parser_type'] = 'text' - super().__init__(**kwargs) - - @property - def role_name(self) -> str: - return "drawio_xml_generator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_drawio_xml_generator" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_drawio_xml_generator" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """从 pre_tool_results 中获取 prompt 参数""" - diagram_plan = pre_tool_results.get("diagram_plan", "") - diagram_type = pre_tool_results.get("diagram_type", "auto") - diagram_style = pre_tool_results.get("diagram_style", "default") - text_content = pre_tool_results.get("text_content", "") - validation_feedback = pre_tool_results.get("validation_feedback", "") - language = pre_tool_results.get("language", "") - - return { - "diagram_plan": diagram_plan, - "diagram_type": diagram_type, - "diagram_style": diagram_style, - "text_content": text_content, - "validation_feedback": validation_feedback, - "language": language, - } - - def get_default_pre_tool_results(self) -> Dict[str, 
Any]: - """默认的 pre_tool_results""" - return { - "diagram_plan": "", - "diagram_type": "auto", - "diagram_style": "default", - "text_content": "", - "validation_feedback": "", - "language": "", - } - - def get_react_validators(self) -> List[ValidatorFunc]: - return [ - self._validator_has_mxcell, - self._validator_no_markdown, - self._validator_no_xml_comments, - self._validator_xml_valid, - ] - - @staticmethod - def _extract_xml_text(content: str, parsed_result: Dict[str, Any]) -> str: - if isinstance(parsed_result, dict): - return ( - parsed_result.get("text", "") - or parsed_result.get("xml", "") - or parsed_result.get("drawio_xml", "") - or "" - ) - return content or "" - - @classmethod - def _validator_has_mxcell( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result).strip() - if not xml_text or " Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result) - if "```" in xml_text or "```" in content: - return False, "不要输出 markdown 代码块标记(```),仅输出 mxCell XML。" - return True, None - - @classmethod - def _validator_no_xml_comments( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result) - if "),仅输出 mxCell XML。" - return True, None - - @classmethod - def _validator_xml_valid( - cls, - content: str, - parsed_result: Dict[str, Any], - ) -> Tuple[bool, Optional[str]]: - xml_text = cls._extract_xml_text(content, parsed_result).strip() - if not xml_text: - return False, "输出不能为空,请返回 mxCell XML。" - is_valid, errors = validate_xml(xml_text) - if not is_valid: - hint = "请确保所有 & < > 等特殊字符已转义(例如 &),并只输出 mxCell 元素。" - return False, f"XML 解析失败: {'; '.join(errors)}。{hint}" - return True, None - - def update_state_result( - self, - state: Paper2DrawioState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将生成的 XML 写入 state""" - try: - if 
isinstance(result, dict): - xml_content = result.get("text", "") or result.get("xml", "") or result.get("drawio_xml", "") - elif isinstance(result, str): - xml_content = result - else: - xml_content = str(result) - - # 清理 markdown 代码块标记 - if xml_content: - xml_content = xml_content.strip() - if xml_content.startswith("```xml"): - xml_content = xml_content[6:] - elif xml_content.startswith("```"): - xml_content = xml_content[3:] - if xml_content.endswith("```"): - xml_content = xml_content[:-3] - xml_content = sanitize_cells_xml(xml_content) - # NOTE: Temporarily disable hard overlap resolver per request. - # diagram_type = pre_tool_results.get("diagram_type", "auto") - # xml_content = resolve_overlaps(xml_content, diagram_type=diagram_type) - xml_content = xml_content.strip() - - if xml_content: - state.drawio_xml = xml_content - state.drawio_xml_history.append(xml_content) - log.info(f"[DrawioXmlGenerator] XML 生成成功,长度: {len(xml_content)}") - else: - log.warning("[DrawioXmlGenerator] 未生成有效的 XML") - - except Exception as e: - log.error(f"[DrawioXmlGenerator] 更新状态失败: {e}") diff --git a/dataflow_agent/agentroles/paper2any_agents/fig_desc_generator.py b/dataflow_agent/agentroles/paper2any_agents/fig_desc_generator.py deleted file mode 100644 index 690f5d2..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/fig_desc_generator.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations -from typing import Any, Dict, Optional -from dataflow_agent.state import DFState, Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -log = get_logger(__name__) -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -@register("figure_desc_generator") -class FigureDescGenerator(BaseAgent): - @property - def role_name(self) -> str: - return "figure_desc_generator" - - @property - def system_prompt_template_name(self) -> str: 
- if getattr(self.state.request, "figure_complex", "") == "easy": - return "system_prompt_for_figure_desc_generator" - elif getattr(self.state.request, "figure_complex", "") == "hard": - return "system_prompt_for_figure_desc_generator_free" - else: - return "system_prompt_for_figure_desc_generator_mid" - - @property - def task_prompt_template_name(self) -> str: - if getattr(self.state.request, "figure_complex", "") == "easy": - return "task_prompt_for_figure_desc_generator" - elif getattr(self.state.request, "figure_complex", "") == "hard": - return "task_prompt_for_figure_desc_generator_free" - else: - return "task_prompt_for_figure_desc_generator_mid" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - paper_idea = pre_tool_results.get("paper_idea") - return { - "paper_idea": paper_idea, - "style": self.state.request.style - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "paper_idea": "", - } - - def update_state_result(self, state: Paper2FigureState, result: Dict[str, Any], pre_tool_results: Dict[str, Any]): - try: - figure_desc = result.get("figure_desc", "") if isinstance(result, dict) else "" - # print(result) - if figure_desc: - state.fig_desc = figure_desc - except Exception: - pass - return super().update_state_result(state, result, pre_tool_results) - - -# Function to generate figure description -async def figure_desc_generator( - state: Paper2FigureState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 2048, - use_agent: bool = False, - **kwargs, -) -> Paper2FigureState: - inst = create_figure_desc_generator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - return await inst.execute(state, use_agent=use_agent, **kwargs) - - -def create_figure_desc_generator(tool_manager: Optional[ToolManager] = None, **kwargs) -> FigureDescGenerator: - if 
tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - return FigureDescGenerator(tool_manager=tool_manager, **kwargs) diff --git a/dataflow_agent/agentroles/paper2any_agents/icon_editor.py b/dataflow_agent/agentroles/paper2any_agents/icon_editor.py deleted file mode 100644 index 9a62917..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/icon_editor.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, List, Optional -from langchain_openai import ChatOpenAI -from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage -from langchain_core.tools import Tool - -from dataflow_agent.promptstemplates.prompt_template import PromptsTemplateGenerator -from dataflow_agent.state import DFState -from dataflow_agent.utils import robust_parse_json -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -log = get_logger(__name__) - -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -@register("icon_editor") -class IconEditor(BaseAgent): - """图标编辑器 - 根据图片和提示词进行二次编辑""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "icon_editor" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_icon_editing" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_icon_editing" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """图标编辑器特有的提示词参数""" - return { - 'original_image': pre_tool_results.get('image', ''), - 'edit_instructions': pre_tool_results.get('instructions', ''), - 'edit_strength': pre_tool_results.get('strength', 0.8), - } - - def 
get_default_pre_tool_results(self) -> Dict[str, Any]: - """图标编辑器的默认前置工具结果""" - return { - 'image': '', - 'instructions': '', - 'strength': 0.8 - } - - def update_state_result(self, state: DFState, result: Dict[str, Any], pre_tool_results: Dict[str, Any]): - """自定义状态更新 - 保存编辑后的图片""" - # 假设 result 包含编辑后的图片URL或base64数据 - if isinstance(result, dict): - state.icon_image = result.get('edited_image_url', result.get('edited_image_data', result)) - # 可选:保存编辑历史 - if not hasattr(state, 'icon_edit_history'): - state.icon_edit_history = [] - state.icon_edit_history.append({ - 'original': pre_tool_results.get('image', ''), - 'instructions': pre_tool_results.get('instructions', ''), - 'result': state.icon_image - }) - else: - state.icon_image = result - super().update_state_result(state, result, pre_tool_results) - - -async def icon_editing( - state: DFState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 512, - use_agent: bool = False, - **kwargs, -) -> DFState: - """编辑图标的入口函数""" - editor = IconEditor( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - return await editor.execute(state, use_agent=use_agent, **kwargs) - - -def create_icon_editor(tool_manager: Optional[ToolManager] = None, **kwargs) -> IconEditor: - """创建图标编辑器实例""" - return IconEditor(tool_manager=tool_manager, **kwargs) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/icon_generator.py b/dataflow_agent/agentroles/paper2any_agents/icon_generator.py deleted file mode 100644 index 2336891..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/icon_generator.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, List, Optional -from langchain_openai import ChatOpenAI -from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage -from langchain_core.tools 
import Tool - -from dataflow_agent.promptstemplates.prompt_template import PromptsTemplateGenerator -from dataflow_agent.state import DFState -from dataflow_agent.utils import robust_parse_json -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -log = get_logger(__name__) - -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -@register("icon_generator") -class IconGenerator(BaseAgent): - """图标生成器 - 根据text2img prompt生成图标图片""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "icon_generator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_icon_generation" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_icon_generation" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """图标生成器特有的提示词参数""" - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """图标生成器的默认前置工具结果""" - return { - 'prompt': '', - 'size': '512x512', - 'num_images': 1 - } - - def update_state_result(self, state: DFState, result: Dict[str, Any], pre_tool_results: Dict[str, Any]): - """自定义状态更新 - 保存生成的图片""" - # 假设 result 包含图片URL或base64数据 - if isinstance(result, dict): - state.icon_image = result.get('image_url', result.get('image_data', result)) - else: - state.icon_image = result - super().update_state_result(state, result, pre_tool_results) - - -async def icon_generation( - state: DFState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 512, - 
use_agent: bool = False, - **kwargs, -) -> DFState: - """生成图标的入口函数""" - generator = IconGenerator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - return await generator.execute(state, use_agent=use_agent, **kwargs) - - -def create_icon_generator(tool_manager: Optional[ToolManager] = None, **kwargs) -> IconGenerator: - """创建图标生成器实例""" - return IconGenerator(tool_manager=tool_manager, **kwargs) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/icon_prompt_generator.py b/dataflow_agent/agentroles/paper2any_agents/icon_prompt_generator.py deleted file mode 100644 index fa17df3..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/icon_prompt_generator.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import Any, Dict, List, Optional -from langchain_openai import ChatOpenAI -from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage -from langchain_core.tools import Tool - -from dataflow_agent.promptstemplates.prompt_template import PromptsTemplateGenerator -from dataflow_agent.state import MainState -from dataflow_agent.utils import robust_parse_json -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -@register("icon_prompt_generator") -class IconPromptGenerator(BaseAgent): - """图标提示词生成器 - 根据用户关键词生成text2img prompt""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "icon_prompt_generator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_icon_prompt_generation" - - @property - def 
task_prompt_template_name(self) -> str: - return "task_prompt_for_icon_prompt_generation" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """图标提示词生成器特有的提示词参数""" - return { - 'user_keywords': pre_tool_results.get('keywords', ''), - 'style_preferences': pre_tool_results.get('style', 'kartoon, minimalist'), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """图标提示词生成器的默认前置工具结果""" - return { - 'keywords': '', - 'style': 'kartoon, minimalist' - } - - def update_state_result(self, state: MainState, result: Dict[str, Any], pre_tool_results: Dict[str, Any]): - """自定义状态更新 - 保存生成的prompt""" - state.icon_prompt = result.get('icon_prompt', result) if isinstance(result, dict) else result - super().update_state_result(state, result, pre_tool_results) - - -async def icon_prompt_generation( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.7, - max_tokens: int = 512, - use_agent: bool = False, - **kwargs, -) -> MainState: - """生成图标提示词的入口函数""" - generator = IconPromptGenerator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - return await generator.execute(state, use_agent=use_agent, **kwargs) - - -def create_icon_prompt_generator(tool_manager: Optional[ToolManager] = None, **kwargs) -> IconPromptGenerator: - """创建图标提示词生成器实例""" - return IconPromptGenerator(tool_manager=tool_manager, **kwargs) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/image_filter_agent.py b/dataflow_agent/agentroles/paper2any_agents/image_filter_agent.py deleted file mode 100644 index dfec353..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/image_filter_agent.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -Image Filter Agent -""" -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from 
dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("image_filter_agent") -class ImageFilterAgent(BaseAgent): - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "image_filter_agent" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_image_filter_agent" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_image_filter_agent" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "query": pre_tool_results.get("query", ""), - "image_items_json": pre_tool_results.get("image_items_json", "[]"), - } - - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - state.filtered_image_items = result.get("selected_items", []) - log.info(f"[image_filter_agent]: selected {len(state.filtered_image_items)} images") - super().update_state_result(state, result, pre_tool_results) - - -async def image_filter_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - agent = ImageFilterAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - 
react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_image_filter_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> ImageFilterAgent: - return ImageFilterAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/kb_image_insert_agent.py b/dataflow_agent/agentroles/paper2any_agents/kb_image_insert_agent.py deleted file mode 100644 index 53c8f07..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/kb_image_insert_agent.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -KB Image Insert Agent -""" -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("kb_image_insert_agent") -class KBImageInsertAgent(BaseAgent): - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return 
"kb_image_insert_agent" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_kb_image_insert_agent" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_kb_image_insert_agent" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "pagecontent_json": pre_tool_results.get("pagecontent_json", "[]"), - "image_items_json": pre_tool_results.get("image_items_json", "[]"), - } - - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - current_pagecontent = getattr(state, "pagecontent", []) or [] - new_pagecontent = result.get("pagecontent", None) - if isinstance(new_pagecontent, list) and new_pagecontent: - state.pagecontent = new_pagecontent - elif new_pagecontent is None: - state.pagecontent = current_pagecontent - else: - # 如果插图结果为空,保留原有 pagecontent,避免生成 0 页 - state.pagecontent = current_pagecontent - log.info(f"[kb_image_insert_agent]: pagecontent size {len(state.pagecontent)}") - super().update_state_result(state, result, pre_tool_results) - - -async def kb_image_insert_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - agent = KBImageInsertAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, 
use_agent=use_agent, **kwargs) - - -def create_kb_image_insert_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> KBImageInsertAgent: - return KBImageInsertAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/long_paper_outline_agent.py b/dataflow_agent/agentroles/paper2any_agents/long_paper_outline_agent.py deleted file mode 100644 index 5b2e66b..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/long_paper_outline_agent.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -LongPaperOutlineAgent agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Description: 专门用于处理长文档分批生成 PPT 大纲的 Agent。 -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("long_paper_outline_agent") -class LongPaperOutlineAgent(BaseAgent): - """ - LongPaperOutlineAgent: 负责接收分批次的长文本,生成对应的 PPT 大纲页面。 
- """ - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: - return "long_paper_outline_agent" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_outline_agent" - - @property - def task_prompt_template_name(self) -> str: - if getattr(self.state, "is_first", False): - return "task_prompt_for_long_paper_outline_agent_first" - if getattr(self.state, "is_last", False): - return "task_prompt_for_long_paper_outline_agent_last" - return "task_prompt_for_long_paper_outline_agent_middle" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """ - 构造 Prompt 参数。 - 需要 Workflow 传入: - - current_chunk: 当前批次的文本内容 - - batch_info: 批次信息 (index, total, etc.) - """ - batch_info = pre_tool_results.get("batch_info", {}) - - return { - "current_chunk": self.state.current_chunk, - "batch_index": batch_info.get("batch_index", 1), - "total_batches": batch_info.get("total_batches", 1), - "pages_to_generate": batch_info.get("pages_to_generate", 10), - "is_first": batch_info.get("is_first", False), - "is_last": batch_info.get("is_last", False), - "page_count" : self.state.request.page_count, - "language": self.state.request.language, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "current_chunk": "", - "batch_info": {} - } - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """ - 将生成的结果(JSON List)写回 State。 - 注意:在 Workflow 的 generate_outline_for_batch 中, - 会从返回的 State 中读取 pagecontent。 - """ - # 结果预期是一个 List[Dict] (页面列表) - state.pagecontent = result - log.info(f"[long_paper_outline_agent] 生成了 {len(result) if isinstance(result, list) else 0} 页内容") - 
super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def long_paper_outline_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - agent = LongPaperOutlineAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_long_paper_outline_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> LongPaperOutlineAgent: - return LongPaperOutlineAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/outline_agent.py 
b/dataflow_agent/agentroles/paper2any_agents/outline_agent.py deleted file mode 100644 index 5068f0b..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/outline_agent.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -OutlineAgent agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-17 19:51:17 -生成位置: dataflow_agent/agentroles/common_agents/outline_agent_agent.py - -本文件由 `dfa create --agent_name outline_agent` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("outline_agent") -class OutlineAgent(BaseAgent): - """TODO: 描述 outline_agent 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "outline_agent" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_outline_agent" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_outline_agent" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': 
pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return { - "minueru_output": pre_tool_results.get("minueru_output", ""), - "text_content": pre_tool_results.get("text_content", ""), - "page_count" : self.state.request.page_count, - "language": self.state.request.language, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - state.pagecontent = result - log.info(f"[outline_agent]: outline_agent 生成了 {len(result)} 页内容") - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def outline_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """outline_agent 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - 
agent = OutlineAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_outline_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> OutlineAgent: - return OutlineAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/outline_refine_agent.py b/dataflow_agent/agentroles/paper2any_agents/outline_refine_agent.py deleted file mode 100644 index 0217068..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/outline_refine_agent.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -OutlineRefineAgent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Refines an existing PPT outline based on user feedback. 
-""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("outline_refine_agent") -class OutlineRefineAgent(BaseAgent): - """Refine existing outline content while keeping page order and count.""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: # noqa: D401 - return "outline_refine_agent" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_outline_refine_agent" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_outline_refine_agent" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "pagecontent": pre_tool_results.get("pagecontent", "[]"), - "outline_feedback": pre_tool_results.get("outline_feedback", ""), - "minueru_output": pre_tool_results.get("minueru_output", ""), - "text_content": pre_tool_results.get("text_content", ""), - "page_count": self.state.request.page_count, - "language": self.state.request.language, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return {} - - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - original = pre_tool_results.get("pagecontent_raw") - if not isinstance(original, list): - original = getattr(state, "pagecontent", []) or [] - - if not isinstance(result, list): - log.warning("[outline_refine_agent] Invalid result, fallback to original pagecontent.") - state.pagecontent = original - super().update_state_result(state, original, 
pre_tool_results) - return - - merged_pages = [] - for item in result: - if isinstance(item, dict): - merged = item.copy() - merged_pages.append(merged) - - state.pagecontent = merged_pages - log.info(f"[outline_refine_agent] refined {len(merged_pages)} pages") - super().update_state_result(state, merged_pages, pre_tool_results) - - -async def outline_refine_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """Async entry for outline_refine_agent.""" - agent = OutlineRefineAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_outline_refine_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> OutlineRefineAgent: - return OutlineRefineAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - 
use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/p2v_beamer_code_debug_agent.py b/dataflow_agent/agentroles/paper2any_agents/p2v_beamer_code_debug_agent.py deleted file mode 100644 index afa15a6..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/p2v_beamer_code_debug_agent.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -P2vBeamerCodeDebug agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-11-29 15:41:18 -生成位置: dataflow_agent/agentroles/common_agents/p2v_beamer_code_debug_agent.py - -本文件由 `dfa create --agent_name p2v_beamer_code_debug` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("p2v_beamer_code_debug") -class P2vBeamerCodeDebug(BaseAgent): - """用来debug得到的beamer代码""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "p2v_beamer_code_debug" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_p2v_beamer_code_debug" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_p2v_beamer_code_debug" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, 
pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数""" - return { - "beamer_code": pre_tool_results.get("beamer_code", ""), - "code_debug_result": pre_tool_results.get("code_debug_result", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果 {latex_code: xxxx} 写回 MainState""" - beamer_code = result.get("latex_code", '') - beamer_code_path = state.beamer_code_path - if beamer_code and beamer_code_path: - from pathlib import Path - - tex_path = Path(beamer_code_path) - tex_path.write_text(beamer_code, encoding='utf-8') - # 编译最新的tex代码 - from dataflow_agent.toolkits.p2vtool.p2v_tool import compile_tex - is_beamer_wrong, is_beamer_warning, code_debug_result = compile_tex(beamer_code_path) - state.ppt_path = beamer_code_path.replace(".tex", ".pdf") - log.info(f"将更新好的beamer code写回 {beamer_code_path}") - else: - log.error(f"Failed to update beamer code: missing latex_code or beamer_code_path") - - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def p2v_beamer_code_debug( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """p2v_beamer_code_debug 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 
采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = P2vBeamerCodeDebug( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_p2v_beamer_code_debug( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> P2vBeamerCodeDebug: - return P2vBeamerCodeDebug.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/p2v_extract_pdf_agent.py b/dataflow_agent/agentroles/paper2any_agents/p2v_extract_pdf_agent.py deleted file mode 100644 index 55e21bd..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/p2v_extract_pdf_agent.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -P2vExtractPdf agent 
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-11-26 20:06:54 - -本文件由 `dfa create --agent_name p2v_extract_pdf` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.toolkits.p2vtool.p2v_tool import extract_beamer_code -from pathlib import Path - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("p2v_extract_pdf") -class P2vExtractPdf(BaseAgent): - """ - 从pdf中提取图片,文字,公式,列表 - """ - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "p2v_extract_pdf" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_p2v_extract_pdf" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_p2v_extract_pdf" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数""" - return { - "pdf_markdown": pre_tool_results.get("pdf_markdown", ""), - "pdf_images_working_dir": pre_tool_results.get("pdf_images_working_dir", ""), - "output_language": pre_tool_results.get("output_language", "English") - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 
结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - try: - raw_beamer_code = result.get("latex_code", "") if isinstance(result, dict) else "" - beamer_code = "" - if isinstance(raw_beamer_code, str): - # state.beamer_code_path = extract_beamer_code(beamer_code) - beamer_code = extract_beamer_code(raw_beamer_code) - - pdf_path = Path(state.request.get("paper_pdf_path", "")) - beamer_code_path = pdf_path.with_suffix('').expanduser().resolve() / "auto/beamer_code.tex" - state.beamer_code_path = str(beamer_code_path) - beamer_code_path.write_text(beamer_code, encoding='utf-8') - log.info(f"获得到了beamer code,内容保存在 {beamer_code_path}") - else: - log.error(f"无法处理的 latex_code 类型: {type(raw_beamer_code)}") - except Exception: - pass - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def p2v_extract_pdf( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """p2v_extract_pdf 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - 
vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = P2vExtractPdf( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_p2v_extract_pdf( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> P2vExtractPdf: - return P2vExtractPdf.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/p2v_pdf2ppt_agent.py b/dataflow_agent/agentroles/paper2any_agents/p2v_pdf2ppt_agent.py deleted file mode 100644 index afbe3d1..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/p2v_pdf2ppt_agent.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -P2vPdf2ppt agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-11-29 15:42:01 -生成位置: dataflow_agent/agentroles/common_agents/p2v_pdf2ppt_agent.py - -本文件由 `dfa create --agent_name p2v_pdf2ppt` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("p2v_pdf2ppt") -class P2vPdf2ppt(BaseAgent): - """TODO: 描述 p2v_pdf2ppt 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "p2v_pdf2ppt" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_p2v_pdf2ppt" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_p2v_pdf2ppt" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return {} - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - state.xx = result - - super().update_state_result(state, result, pre_tool_results) - - -# 
---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def p2v_pdf2ppt( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """p2v_pdf2ppt 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = P2vPdf2ppt( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_p2v_pdf2ppt( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> P2vPdf2ppt: 
- return P2vPdf2ppt.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/p2v_subtitle_and_cursor_agent.py b/dataflow_agent/agentroles/paper2any_agents/p2v_subtitle_and_cursor_agent.py deleted file mode 100644 index cb4ce90..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/p2v_subtitle_and_cursor_agent.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -P2vSubtitleAndCursor agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-10 15:38:41 -生成位置: dataflow_agent/agentroles/common_agents/p2v_subtitle_and_cursor_agent.py - -本文件由 `dfa create --agent_name p2v_subtitle_and_cursor` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("p2v_subtitle_and_cursor") -class P2vSubtitleAndCursor(BaseAgent): - """为每张幻灯片生成配音字幕,并规划光标移动提示(配合讲解逻辑)""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "p2v_subtitle_and_cursor" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_p2v_subtitle_and_cursor" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_p2v_subtitle_and_cursor" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - # prompt中不需要占位符 - return {} - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - # state.subtitle_and_cursor_path = result - subtitle_and_cursor_info = result.get("subtitle_and_cursor", "") - log.info(f"获取了单张slide的Subtitle and Cursor 信息: {subtitle_and_cursor_info}") - state.subtitle_and_cursor.append(subtitle_and_cursor_info) - - super().update_state_result(state, result, 
pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def p2v_subtitle_and_cursor( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """p2v_subtitle_and_cursor 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = P2vSubtitleAndCursor( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_p2v_subtitle_and_cursor( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - 
vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> P2vSubtitleAndCursor: - return P2vSubtitleAndCursor.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/paper_idea_extractor.py b/dataflow_agent/agentroles/paper2any_agents/paper_idea_extractor.py deleted file mode 100644 index 176e109..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/paper_idea_extractor.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations -from typing import Any, Dict, Optional -from dataflow_agent.state import DFState, Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent - -log = get_logger(__name__) - -@register("paper_idea_extractor") -class PaperIdeaExtractor(BaseAgent): - @property - def role_name(self) -> str: - return "paper_idea_extractor" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_paper_idea_extractor" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_paper_idea_extractor" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - paper_content = pre_tool_results.get("paper_content") - return { - "paper_content": paper_content, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "paper_content": "", - } - - def update_state_result(self, state: Paper2FigureState, result: Dict[str, Any], pre_tool_results: Dict[str, Any]): - - state.paper_idea = result.get("paper_idea", 
"") if result.get("paper_idea", "") != "" else result - - return super().update_state_result(state, result, pre_tool_results) - - -# Function to extract paper ideas -async def paper_idea_extractor( - state: Paper2FigureState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 2048, - use_agent: bool = False, - **kwargs, -) -> Paper2FigureState: - inst = create_paper_idea_extractor( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - ) - return await inst.execute(state, use_agent=use_agent, **kwargs) - - -def create_paper_idea_extractor(tool_manager: Optional[ToolManager] = None, **kwargs) -> PaperIdeaExtractor: - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - return PaperIdeaExtractor(tool_manager=tool_manager, **kwargs) diff --git a/dataflow_agent/agentroles/paper2any_agents/qa_agent.py b/dataflow_agent/agentroles/paper2any_agents/qa_agent.py deleted file mode 100644 index 6909057..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/qa_agent.py +++ /dev/null @@ -1,165 +0,0 @@ -""" -QaAgent agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2026-01-21 18:15:51 -生成位置: dataflow_agent/agentroles/common_agents/qa_agent_agent.py - -本文件由 `dfa create --agent_name qa_agent` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("qa_agent") -class QaAgent(BaseAgent): - """TODO: 描述 qa_agent 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "qa_agent" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_qa_agent" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_qa_agent" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return {} - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - state.xx = result - - super().update_state_result(state, result, pre_tool_results) - - -# 
---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def qa_agent( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """qa_agent 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = QaAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_qa_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> QaAgent: - return 
QaAgent.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) \ No newline at end of file diff --git a/dataflow_agent/agentroles/paper2any_agents/svg_bg_cleaner_agent.py b/dataflow_agent/agentroles/paper2any_agents/svg_bg_cleaner_agent.py deleted file mode 100644 index f03fdf4..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/svg_bg_cleaner_agent.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -SvgBgCleaner agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -参考 TechnicalRouteDescGenerator,实现对 SVG 代码的“去文本”清洗。 - -职责: -- 从 MainState 中读取原始 SVG 源码(通常在 figure_tec_svg_content 字段); -- 调用 LLM,根据 prompt 清洗掉所有文本元素,仅保留图形相关元素; -- 将结果写入 state.agent_results["svg_bg_cleaner"]["svg_bg_code"],供 workflow 使用; -- 提供严格的 ReAct 验证器: - 1. JSON 结构正确,包含 svg_bg_code; - 2. svg_bg_code 是合法 SVG(XML 解析通过); - 3. 
svg_bg_code 中不再包含 标签。 -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("svg_bg_cleaner") -class SvgBgCleaner(BaseAgent): - """对 SVG 代码进行“去文本”清洗的 Agent""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "svg_bg_cleaner" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_svg_bg_cleaner" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_svg_bg_cleaner" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """ - 将当前 state 中的原始 SVG 代码传入 prompt。 - 约定: - - 上游已将 SVG 源码写入 state.figure_tec_svg_content。 - """ - return { - "svg_code": self.state.figure_tec_svg_content - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return {} - - # ---------- ReAct 验证器 ---------- - def get_react_validators(self) -> List: - """ - 验证器签名: - validator(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str] - - 这里对 parsed_result["svg_bg_code"] 做两类检查: - 1) XML 层面:是否为合法 SVG; - 2) 语义层面:是否不再包含文本相关标签(text / tspan / title)。 - """ - - def _extract_svg_fragment(svg_code: str) -> str: - """ - 从模型返回的字符串中,提取出干净的 <svg>...</svg> 片段。 - - 步骤: - 1. 去掉首尾空白; - 2. 去掉可能的 ``` / ```svg 代码块包裹; - 3. 截取第一个 <svg ...> 到最后一个 </svg> 之间的部分。 - """ - if not svg_code: - return "" - - text = svg_code.strip() - - # 处理 ```svg ... ``` 或 ``` ... 
``` 代码块 - if text.startswith("```"): - lines = [line for line in text.splitlines() if line.strip("`").strip()] - for i, line in enumerate(lines): - if "<svg" in line or "<SVG" in line or "<Svg" in line or "<svg" in line or "<SVG" in line: - text = "\n".join(lines[i:]) - break - - # 既兼容已转义的 <svg>,也兼容原始 <svg> - candidates = ["<svg", "<SVG", "<Svg", "<svg", "<SVG", "<Svg"] - start = -1 - for c in candidates: - start = text.find(c) - if start != -1: - break - - end = -1 - for c in ["</svg>", "</svg>"]: - pos = text.rfind(c) - if pos != -1: - end = pos + len(c) - break - - if start == -1 or end == -1: - return text - - return text[start:end].strip() - - def validate_svg_bg(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """ - SvgBgCleaner 的 SVG 验证器: - 1. 检查 parsed_result 中存在 svg_bg_code; - 2. 检查 svg_bg_code XML well-formed; - 3. 检查不包含 text/tspan/title 标签(大小写不敏感)。 - """ - import re - import xml.etree.ElementTree as ET - - if not isinstance(parsed_result, dict): - return False, "解析结果不是字典,无法找到 svg_bg_code 字段" - - svg_bg_code = parsed_result.get("svg_bg_code") - if not svg_bg_code: - return False, "缺少 svg_bg_code 字段或内容为空" - - fragment = _extract_svg_fragment(svg_bg_code) - if "svg" not in fragment.lower(): - return False, "返回内容中未检测到 <svg> 根标签" - - # 1) XML 合法性检查 - try: - # 如果是 HTML 实体转义过的 <svg> 形式,先简单还原 - xml_text = fragment.replace("<", "<").replace(">", ">") - ET.fromstring(xml_text) - except Exception as e: - log.warning(f"svg_bg_cleaner.validate_svg_bg: SVG XML 解析失败: {e}") - return False, f"svg_bg_code 不是合法 XML: {e}" - - # 2) 文本标签检查:不应再包含 text / tspan / title - lowered = fragment.lower() - text_like = ["<text", "<text", "<tspan", "<tspan", "<title", "<title"] - if any(t in lowered for t in text_like): - return False, "svg_bg_code 中仍包含 text/tspan/title 文本相关标签" - - return True, "SVG 背景清洗结果验证通过" - - return [validate_svg_bg] - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - 
pre_tool_results: Dict[str, Any], - ): - """ - 期望 LLM 返回: - {"svg_bg_code": "<svg ...>...</svg>"} - """ - svg_bg_code = None - if isinstance(result, dict): - svg_bg_code = result.get("svg_bg_code") - - state.svg_bg_code = svg_bg_code - - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def svg_bg_cleaner( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """ - svg_bg_cleaner 的异步入口 - """ - agent = SvgBgCleaner( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_svg_bg_cleaner( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> SvgBgCleaner: - """ - 工厂函数,便于在其他模块中创建 svg_bg_cleaner agent。 - """ - return SvgBgCleaner.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - 
tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/table_extractor_agent.py b/dataflow_agent/agentroles/paper2any_agents/table_extractor_agent.py deleted file mode 100644 index 5ffefa7..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/table_extractor_agent.py +++ /dev/null @@ -1,332 +0,0 @@ -""" -TableExtractor agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-17 22:53:23 -生成位置: dataflow_agent/agentroles/common_agents/table_extractor_agent.py - -本文件由 `dfa create --agent_name table_extractor` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from pathlib import Path -import re - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent, ValidatorFunc -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.utils import get_project_root - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("table_extractor") -class TableExtractor(BaseAgent): - """从 MinerU 输出中定位指定表格并生成 HTML(LLM),并将 HTML 渲染为 PNG 写入 state.table_img_path""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "table_extractor" - - @property - def 
system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_table_extractor" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_table_extractor" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return { - 'minueru_output': self.state.minueru_output, - 'table_num': self.state.asset_ref, - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {"table_num": ""} - - # ---------- ReAct Validators ---------- - def get_react_validators(self) -> List[ValidatorFunc]: - return [ - self._default_json_validator, - self._validator_has_html_code, - self._validator_html_not_markdown, - self._validator_html_has_table_tag, - self._validator_html_renderable, # 新增渲染验证器 - ] - - def _validator_html_renderable(self, content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, Optional[str]]: - """ - 尝试渲染 HTML,如果失败则反馈给 LLM 让其重试/简化。 - """ - html_code = (parsed_result.get("html_code") or "") if isinstance(parsed_result, dict) else "" - if not html_code: - return True, None # 前面的验证器会拦截空内容,这里跳过 - - # 构造完整文档 - full_html = self._wrap_html_document(html_code) - - # 使用临时文件测试渲染 - import tempfile - import os - - try: - with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: - tmp_path = tmp.name - - # 尝试渲染 - self._render_html_to_png(full_html, tmp_path) - - # 检查文件是否生成且非空 - if not os.path.exists(tmp_path) or os.path.getsize(tmp_path) == 0: - return False, "HTML 代码看起来正确,但渲染引擎无法生成图片。请尝试简化 HTML 结构,去除复杂的 CSS 或特殊字符,只保留最基本的表格结构。" - - # 渲染成功,清理临时文件 - try: - os.remove(tmp_path) - except Exception: - pass - - return True, None - - except 
Exception as e: - err_msg = str(e) - # 提取关键错误信息反馈给 LLM - if "Could not write to output file" in err_msg or "Could not save image" in err_msg: - return False, f"HTML 渲染失败 (System Error)。请尝试极大简化 HTML 代码,不要使用任何外部资源引用,不要使用复杂的样式,确保是一个标准的、最简的 HTML 表格。" - - return False, f"HTML 渲染抛出异常: {err_msg[:200]}... 请检查代码是否包含导致渲染引擎崩溃的非法结构。" - - @staticmethod - def _validator_has_html_code(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, Optional[str]]: - if not isinstance(parsed_result, dict): - return False, "返回结果必须是 JSON 对象。" - if "html_code" not in parsed_result: - return False, '缺少字段 "html_code",请按格式返回:{"html_code":"..."}' - if not isinstance(parsed_result.get("html_code"), str) or not parsed_result.get("html_code", "").strip(): - return False, '"html_code" 必须是非空字符串。' - return True, None - - @staticmethod - def _validator_html_not_markdown(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, Optional[str]]: - html = (parsed_result.get("html_code") or "") if isinstance(parsed_result, dict) else "" - if "```" in html or "```" in content: - return False, "不要输出 markdown 代码块标记(```),只返回纯 JSON。" - return True, None - - @staticmethod - def _validator_html_has_table_tag(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, Optional[str]]: - html = (parsed_result.get("html_code") or "") if isinstance(parsed_result, dict) else "" - if "<table" not in html.lower(): - return False, 'html_code 中必须包含 <table ...> 表格结构。' - if "</table>" not in html.lower(): - return False, 'html_code 中必须包含 </table> 闭合标签。' - return True, None - - # ---------- Render helpers ---------- - @staticmethod - def _normalize_table_num(raw: Any) -> str: - """ - 兼容: - - "Table 2" / "table_2" / "2" - 输出统一 key:table_2 - """ - if raw is None: - return "" - s = str(raw).strip() - if not s: - return "" - m = re.search(r"(\d+)", s) - if m: - return f"table_{m.group(1)}" - return s.lower().replace(" ", "_") - - @staticmethod - def _wrap_html_document(table_html: str) -> str: - """ - 兜底包装成完整 HTML 文档,保证 
wkhtmltoimage 渲染稳定。 - """ - css = """ - <style> - body { font-family: Arial, Helvetica, sans-serif; margin: 20px; } - table { border-collapse: collapse; width: 100%; font-size: 14px; } - th, td { border: 1px solid #333; padding: 6px 8px; vertical-align: top; } - caption { caption-side: top; font-weight: 700; margin-bottom: 8px; } - </style> - """.strip() - return f"<!doctype html><html><head><meta charset='utf-8'>{css}</head><body>{table_html}</body></html>" - - def _render_html_to_png(self, html_content: str, save_path: str) -> None: - """ - 使用 imgkit(wkhtmltoimage) 渲染 HTML -> PNG。 - """ - import imgkit - - options = { - "format": "png", - "encoding": "UTF-8", - "quality": "100", - "enable-local-file-access": "", - } - imgkit.from_string(html_content, save_path, options=options) - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,并将 html_code 渲染成图片写入 state.table_img_path""" - super().update_state_result(state, result, pre_tool_results) - - if not isinstance(result, dict): - return - - html_code = str(result.get("html_code") or "").strip() - if not html_code: - return - - # 输出目录:优先 state.result_path,其次项目 outputs/table_extractor - base_dir = getattr(state, "result_path", "") or "" - if base_dir: - out_dir = Path(base_dir) / "tables" - else: - out_dir = get_project_root() / "outputs" / "table_extractor" - out_dir.mkdir(parents=True, exist_ok=True) - - table_key = self._normalize_table_num(self.state.asset_ref) - - file_name = f"{table_key}.png" if table_key else "table.png" - - png_path = str((out_dir / file_name).resolve()) - - html_content = self._wrap_html_document(html_code) - - try: - self._render_html_to_png(html_content, png_path) - - state.table_img_path = png_path - - log.critical(f'[table_img_path 表格图像路径]: {png_path}') - - # 同步到 result,方便下游直接读 - result["table_img_path"] = png_path - except Exception as e: - # 不抛异常,避免影响主流程;把错误写回结果 - 
log.error(f"渲染表格图片失败: {e}", exc_info=True) - result["render_error"] = str(e) - - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def table_extractor( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """table_extractor 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = TableExtractor( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_table_extractor( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, 
- use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> TableExtractor: - return TableExtractor.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/table_text_renderer.py b/dataflow_agent/agentroles/paper2any_agents/table_text_renderer.py deleted file mode 100644 index 2ccd6ee..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/table_text_renderer.py +++ /dev/null @@ -1,412 +0,0 @@ -""" -TableTextRenderer agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -从表格文本生成 matplotlib 代码并渲染为表格图片 - -支持复杂表格结构: -- 多级表头(跨行/跨列合并) -- LaTeX 表格格式 -- Markdown 表格格式 -- CSV/TSV 格式 - -用于 Paper2ExpFigure 工作流的 TEXT 模式 -""" - -from __future__ import annotations - -import json -from pathlib import Path -from typing import Any, Dict, List, Optional - -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.registry import register -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.utils import execute_matplotlib_code - -log = get_logger(__name__) - - -@register("table_text_renderer") -class TableTextRenderer(BaseAgent): - """从表格文本生成渲染代码的 Agent""" - - @property - def role_name(self) -> str: - return "table_text_renderer" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_table_text_renderer" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_table_text_renderer" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """从 pre_tool_results 中获取 prompt 参数""" - return { - 
"table_text": pre_tool_results.get("table_text", ""), - "table_title": pre_tool_results.get("table_title", ""), - "output_path": pre_tool_results.get("output_path", "table_output.png"), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """默认的 pre_tool_results""" - return { - "table_text": "", - "table_title": "", - "output_path": "table_output.png", - } - - def update_state_result( - self, - state, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将生成的代码和解析结果写入 state""" - try: - if isinstance(result, dict): - code = result.get("code", "") - table_structure = result.get("table_structure", {}) - - # 存储到 temp_data - if not hasattr(state, 'temp_data') or state.temp_data is None: - state.temp_data = {} - - state.temp_data["table_render_code"] = code - state.temp_data["table_structure"] = table_structure - - log.info(f"[TableTextRenderer] 生成代码长度: {len(code)}") - log.info(f"[TableTextRenderer] 多级表头: {table_structure.get('has_multi_level_header', False)}") - except Exception as e: - log.warning(f"[TableTextRenderer] 更新 state 失败: {e}") - - return super().update_state_result(state, result, pre_tool_results) - - async def execute_pre_tools(self, state) -> Dict[str, Any]: - """重写 execute_pre_tools,从 state.pre_tool_results 注入参数""" - results = await super().execute_pre_tools(state) - - # 从 state.pre_tool_results 注入参数 - inject_results = getattr(state, 'pre_tool_results', {}) - for key, value in inject_results.items(): - if value: - results[key] = value - - return results - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -def create_table_text_renderer( - tool_manager: Optional[ToolManager] = None, - **kwargs, -) -> TableTextRenderer: - """创建 TableTextRenderer 实例""" - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - return 
TableTextRenderer(tool_manager=tool_manager, **kwargs) - - -async def render_table_from_text( - table_text: str, - output_path: Path, - state, - title: str = "", - model_name: str = "deepseek-v3.2", - tool_manager: Optional[ToolManager] = None, -) -> tuple: - """ - 从表格文本渲染表格图片的完整流程 - - Args: - table_text: 原始表格文本(支持 LaTeX/Markdown/CSV 等格式) - output_path: 输出图片路径 - state: 状态对象 - title: 表格标题 - model_name: 使用的模型名称 - tool_manager: 工具管理器 - - Returns: - (success, parsed_data): 是否成功,以及解析出的表格结构数据 - """ - from dataflow_agent.agentroles import create_simple_agent - - parsed_data = { - "headers": [], - "rows": [], - "has_multi_level_header": False, - "header_levels": 1, - } - - try: - # 创建 agent - agent = create_simple_agent( - name="table_text_renderer", - model_name=model_name, - temperature=0.1, - max_tokens=4096, - tool_manager=tool_manager, - ) - - # 注入 pre_tool_results - state.pre_tool_results = { - "table_text": table_text, - "table_title": title, - "output_path": str(output_path), - } - - # 执行 agent - state = await agent.execute(state=state, use_agent=False) - - # 获取结果 - agent_result = state.agent_results.get("table_text_renderer", {}).get("results", {}) - code = agent_result.get("code", "") - table_structure = agent_result.get("table_structure", {}) - - if table_structure: - parsed_data.update(table_structure) - - if not code: - log.warning("[render_table_from_text] Agent 未返回代码,使用回退方案") - return _render_table_fallback(table_text, output_path, title), parsed_data - - log.info(f"[render_table_from_text] 生成代码长度: {len(code)} 字符") - - # 在代码前添加 matplotlib 后端设置 - full_code = f''' -import matplotlib -matplotlib.use('Agg') - -{code} -''' - - # 执行代码 - result = execute_matplotlib_code( - code=full_code, - output_path=output_path, - timeout=30, - ) - - if result['success']: - log.info(f"[render_table_from_text] 表格图片已生成: {output_path}") - return True, parsed_data - else: - log.warning(f"[render_table_from_text] 代码执行失败: {result['error']}") - return _render_table_fallback(table_text, 
output_path, title), parsed_data - - except Exception as e: - log.error(f"[render_table_from_text] 渲染失败: {e}") - import traceback - traceback.print_exc() - return _render_table_fallback(table_text, output_path, title), parsed_data - - -def _render_table_fallback( - table_text: str, - output_path: Path, - title: str = "" -) -> bool: - """ - 回退方案:简单解析表格文本并用 matplotlib 渲染 - """ - import matplotlib.pyplot as plt - import matplotlib - matplotlib.use('Agg') - - try: - # 简单解析:按行分割,按常见分隔符分列 - lines = [l.strip() for l in table_text.strip().split('\n') if l.strip()] - if not lines: - return False - - # 跳过 LaTeX 命令行和 markdown 分隔行 - filtered_lines = [] - for l in lines: - # 跳过 LaTeX 命令 - if l.startswith('\\') and not l.startswith('\\hline'): - continue - # 跳过 markdown 分隔行 - if all(c in '-|: ' for c in l): - continue - # 跳过 \hline - if l == '\\hline': - continue - filtered_lines.append(l) - - lines = filtered_lines - if len(lines) < 2: - return False - - # 检测分隔符 - first_line = lines[0] - if '&' in first_line: # LaTeX - sep = '&' - elif '|' in first_line: # Markdown - sep = '|' - elif '\t' in first_line: # TSV - sep = '\t' - elif ',' in first_line: # CSV - sep = ',' - else: - sep = None - - if sep: - headers = [c.strip().replace('\\\\', '').strip() for c in first_line.split(sep) if c.strip()] - rows = [] - for l in lines[1:]: - row = [c.strip().replace('\\\\', '').strip() for c in l.split(sep) if c.strip()] - if row: - rows.append(row) - else: - import re - headers = re.split(r'\s{2,}', first_line) - rows = [re.split(r'\s{2,}', l) for l in lines[1:]] - - if not headers or not rows: - return False - - # 设置中文字体 - plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif'] - plt.rcParams['axes.unicode_minus'] = False - - fig, ax = plt.subplots(figsize=(max(len(headers) * 1.5, 8), max(len(rows) * 0.5 + 1, 4))) - ax.axis('off') - - if title: - ax.set_title(title, fontsize=14, fontweight='bold', pad=20) - - # 确保所有行的列数一致 - max_cols = len(headers) - 
normalized_rows = [] - for row in rows: - if len(row) < max_cols: - row = row + [''] * (max_cols - len(row)) - elif len(row) > max_cols: - row = row[:max_cols] - normalized_rows.append(row) - - table = ax.table( - cellText=normalized_rows, - colLabels=headers, - loc='center', - cellLoc='center', - ) - - table.auto_set_font_size(False) - table.set_fontsize(10) - table.scale(1.2, 1.5) - - for j in range(len(headers)): - cell = table[(0, j)] - cell.set_facecolor('#4472C4') - cell.set_text_props(color='white', fontweight='bold') - - for i in range(1, len(normalized_rows) + 1): - for j in range(len(headers)): - try: - cell = table[(i, j)] - cell.set_facecolor('#D9E2F3' if i % 2 == 0 else 'white') - except: - pass - - plt.savefig(str(output_path), dpi=150, bbox_inches='tight', facecolor='white') - plt.close(fig) - - log.info(f"[_render_table_fallback] 表格图片已生成: {output_path}") - return True - - except Exception as e: - log.error(f"[_render_table_fallback] 生成表格图片失败: {e}") - return False - - -async def split_tables_from_text( - text: str, - state, - model_name: str = "deepseek-v3.2", -) -> List[Dict[str, str]]: - """ - 使用 LLM 分析文本,识别并分割多个表格 - - Args: - text: 包含一个或多个表格的文本 - state: 状态对象 - model_name: 使用的模型名称 - - Returns: - [{"text": "表格文本", "caption": "表格标题"}, ...] 
- """ - from dataflow_agent.agentroles import create_simple_agent - - try: - # 创建 agent - agent = create_simple_agent( - name="table_splitter", - model_name=model_name, - temperature=0.1, - max_tokens=4096, - ) - - # 注入 pre_tool_results - state.pre_tool_results = { - "input_text": text, - } - - # 执行 agent - state = await agent.execute(state=state, use_agent=False) - - # 获取结果 - agent_result = state.agent_results.get("table_splitter", {}).get("results", {}) - tables = agent_result.get("tables", []) - - if tables: - log.info(f"[split_tables_from_text] 识别到 {len(tables)} 个表格") - return tables - else: - log.warning("[split_tables_from_text] 未识别到表格,返回原文本") - return [{"text": text, "caption": ""}] - - except Exception as e: - log.error(f"[split_tables_from_text] 分割表格失败: {e}") - return [{"text": text, "caption": ""}] - - -@register("table_splitter") -class TableSplitter(BaseAgent): - """分割文本中多个表格的 Agent""" - - @property - def role_name(self) -> str: - return "table_splitter" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_table_splitter" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_table_splitter" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "input_text": pre_tool_results.get("input_text", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "input_text": "", - } - - async def execute_pre_tools(self, state) -> Dict[str, Any]: - """重写 execute_pre_tools,从 state.pre_tool_results 注入参数""" - results = await super().execute_pre_tools(state) - inject_results = getattr(state, 'pre_tool_results', {}) - for key, value in inject_results.items(): - if value: - results[key] = value - return results diff --git a/dataflow_agent/agentroles/paper2any_agents/tech_route_reference_analyzer.py b/dataflow_agent/agentroles/paper2any_agents/tech_route_reference_analyzer.py deleted file mode 100644 index a971da7..0000000 --- 
a/dataflow_agent/agentroles/paper2any_agents/tech_route_reference_analyzer.py +++ /dev/null @@ -1,141 +0,0 @@ -""" -TechRouteReferenceAnalyzer agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -使用 VLM 分析技术路线图参考图,提取布局、风格、配色等信息 -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("tech_route_reference_analyzer") -class TechRouteReferenceAnalyzer(BaseAgent): - """使用 VLM 分析技术路线图参考图的 Agent""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "tech_route_reference_analyzer" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_tech_route_reference_analyzer" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_tech_route_reference_analyzer" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数""" - return { - "reference_image_path": pre_tool_results.get("reference_image_path", ""), - "lang": self.state.request.language if hasattr(self.state, "request") else "zh", - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return {"reference_image_path": ""} - - def get_react_validators(self) -> List: - """返回 ReAct 模式下使用的验证器列表""" - - def validate_svg_code(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 代码的完整性""" - if not isinstance(parsed_result, dict): - return False, "返回结果不是有效的 JSON 对象" - - svg_code = parsed_result.get("svg_code") - if not svg_code: - return 
False, "缺少 svg_code 字段" - - if "<svg" not in svg_code.lower(): - return False, "svg_code 中未检测到 <svg> 标签" - - return True, "验证通过" - - return [validate_svg_code] - - def update_state_result( - self, - state: Paper2FigureState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将 VLM 生成的 SVG 代码写回 State""" - if isinstance(result, dict): - if not hasattr(state, "temp_data"): - state.temp_data = {} - svg_code = result.get("svg_code", "") - state.temp_data["reference_svg_code"] = svg_code - log.info(f"[TechRouteReferenceAnalyzer] 参考图 SVG 生成完成,长度: {len(svg_code)}") - super().update_state_result(state, result, pre_tool_results) - - -async def tech_route_reference_analyzer( - state: Paper2FigureState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - use_vlm: bool = True, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> Paper2FigureState: - """tech_route_reference_analyzer 的异步入口""" - agent = TechRouteReferenceAnalyzer( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, **kwargs) - - -def create_tech_route_reference_analyzer( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - use_vlm: bool = True, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> TechRouteReferenceAnalyzer: - """创建 TechRouteReferenceAnalyzer 实例""" - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - - return 
TechRouteReferenceAnalyzer( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py b/dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py deleted file mode 100644 index c710fb8..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py +++ /dev/null @@ -1,261 +0,0 @@ -""" -technical_route_bw_svg_generator agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -基于模板 PNG 生成黑白技术路线图 SVG -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("technical_route_bw_svg_generator") -class TechnicalRouteBWSvgGenerator(BaseAgent): - """参考模板 PNG 生成黑白技术路线图 SVG 的 Agent""" - - @property - def role_name(self) -> str: - return "technical_route_bw_svg_generator" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_technical_route_bw_svg_generator" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_technical_route_bw_svg_generator" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "paper_idea": pre_tool_results.get("paper_idea", ""), - "template_svg_code": pre_tool_results.get("template_svg_code", ""), - "validation_feedback": pre_tool_results.get("validation_feedback", ""), - "lang": self.state.request.language, - } 
- - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "paper_idea": "", - "template_svg_code": "", - "validation_feedback": "", - } - - def get_react_validators(self) -> List: - """返回 ReAct 模式下使用的验证器列表""" - - # 获取用户要求的语言(用于语言验证) - required_lang = getattr(getattr(self.state, "request", None), "language", "en").lower() - - def _extract_svg_fragment(svg_code: str) -> str: - """提取干净的 SVG 片段""" - if not svg_code: - return "" - text = svg_code.strip() - if text.startswith("```"): - lines = [line for line in text.splitlines() if line.strip("`").strip()] - for i, line in enumerate(lines): - if "<svg" in line or "<SVG" in line: - text = "\n".join(lines[i:]) - break - start = text.find("<svg") - end = text.rfind("</svg>") - if start == -1 or end == -1: - return text - end += len("</svg>") - return text[start:end].strip() - - def _inject_chinese_font_for_validation(svg_code: str) -> str: - """为验证注入中文字体(与 workflow 中的逻辑一致)""" - import re - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - if not has_chinese: - return svg_code - - # 中文友好字体列表 - chinese_fonts = 'Noto Sans CJK SC, Microsoft YaHei, SimHei, SimSun, WenQuanYi Zen Hei, sans-serif' - - # 替换所有 font-family 属性 - svg_code = re.sub( - r'font-family="[^"]*"', - f'font-family="{chinese_fonts}"', - svg_code - ) - - # 注入全局样式 - idx = svg_code.find(">") - if idx != -1: - style_block = f""" - <style type="text/css"> - text, tspan {{ - font-family: {chinese_fonts} !important; - }} - </style> -""" - svg_code = svg_code[:idx + 1] + style_block + svg_code[idx + 1:] - - return svg_code - - def validate_svg_structure(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 基本结构""" - import xml.etree.ElementTree as ET - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return False, "缺少 svg_code 字段或内容为空" - - fragment = _extract_svg_fragment(svg_code) - if "<svg" not in fragment.lower(): - return False, "未检测到 <svg> 根标签" - - if "viewBox" not in fragment 
and "viewbox" not in fragment: - return False, "缺少 viewBox 属性,请添加 viewBox" - - try: - ET.fromstring(fragment) - except Exception as e: - return False, f"SVG XML 解析失败: {e}" - - return True, "" - - def validate_chinese_font(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证中文字体设置""" - import re - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - if not has_chinese: - return True, "" # 没有中文,跳过验证 - - # 检查是否使用了不支持中文的字体 - bad_fonts = ["Arial", "Helvetica", "Times", "Courier"] - for font in bad_fonts: - if f'font-family="{font}"' in svg_code or f"font-family='{font}'" in svg_code: - return False, ( - f"检测到中文文本但使用了不支持中文的字体 {font}。" - f"请使用中文友好字体,如:'Noto Sans CJK SC', 'Microsoft YaHei', 'SimHei', 'SimSun', sans-serif" - ) - - return True, "" - - def validate_language_requirement(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 文本语言是否符合用户要求""" - import re - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - - # 如果用户要求英文,但 SVG 包含中文,则验证失败 - if required_lang in ["en", "english"] and has_chinese: - return False, ( - f"用户要求使用英文(language={required_lang}),但生成的 SVG 包含中文字符。" - f"请严格按照用户要求的语言生成 SVG,所有文本标签必须使用英文。" - f"例如:'Data Processing', 'Model Training', 'Feature Extraction' 等。" - ) - - # 如果用户要求中文,但 SVG 不包含中文,给出提示(警告而非错误) - if required_lang in ["zh", "chinese", "cn", "中文"] and not has_chinese: - # 这里不返回失败,因为可能是纯图形的 SVG - # 但可以在日志中记录 - log.warning(f"用户要求使用中文(language={required_lang}),但生成的 SVG 不包含中文字符") - - return True, "" - - def validate_svg_renderable(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 是否可以成功渲染""" - import tempfile - import os - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - 
- # 提取 SVG 片段 - fragment = _extract_svg_fragment(svg_code) - - # 注入中文字体(模拟 workflow 的行为) - fragment_with_font = _inject_chinese_font_for_validation(fragment) - - # 尝试渲染 SVG - try: - from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_svg_render - - # 创建临时文件用于渲染测试 - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file: - tmp_path = tmp_file.name - - try: - local_tool_for_svg_render({ - "svg_code": fragment_with_font, - "output_path": tmp_path, - }) - # 渲染成功,删除临时文件 - if os.path.exists(tmp_path): - os.remove(tmp_path) - return True, "" - except Exception as e: - # 渲染失败,删除临时文件 - if os.path.exists(tmp_path): - os.remove(tmp_path) - return False, f"SVG 渲染失败: {str(e)}。请检查 SVG 代码的格式和结构,确保所有属性值都被正确引号包裹。" - except Exception as e: - return False, f"渲染验证过程出错: {str(e)}" - - return [validate_svg_structure, validate_chinese_font, validate_language_requirement, validate_svg_renderable] - - def update_state_result( - self, - state: Paper2FigureState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - svg_code = None - if isinstance(result, dict): - svg_code = result.get("svg_code") - state.figure_tec_svg_bw_content = svg_code or "" - super().update_state_result(state, result, pre_tool_results) - - -def create_technical_route_bw_svg_generator( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - parser_type: str = "json", - **kwargs, -) -> TechnicalRouteBWSvgGenerator: - """ - 创建技术路线图黑白 SVG 生成器。 - - 注意: 不再使用 VLM (视觉语言模型),而是通过 pre_tool 提供 SVG 模板代码作为文本输入。 - """ - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - - return TechnicalRouteBWSvgGenerator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - parser_type=parser_type, - use_vlm=False, # 不再使用 VLM - **kwargs, - ) diff --git 
a/dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py b/dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py deleted file mode 100644 index 83e191a..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -technical_route_colorize_svg_agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -基于黑白 SVG + 色卡上色生成彩色技术路线图 SVG -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - - -@register("technical_route_colorize_svg") -class TechnicalRouteColorizeSvgAgent(BaseAgent): - """对黑白 SVG 进行配色的 Agent""" - - @property - def role_name(self) -> str: - return "technical_route_colorize_svg" - - @property - def system_prompt_template_name(self) -> str: - return "system_prompt_for_technical_route_colorize_svg" - - @property - def task_prompt_template_name(self) -> str: - return "task_prompt_for_technical_route_colorize_svg" - - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - return { - "bw_svg_code": pre_tool_results.get("bw_svg_code", ""), - "palette_json": pre_tool_results.get("palette_json", ""), - "validation_feedback": pre_tool_results.get("validation_feedback", ""), - "color_template_svg": pre_tool_results.get("color_template_svg", ""), - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - return { - "bw_svg_code": "", - "palette_json": "", - "validation_feedback": "", - "color_template_svg": "", - } - - def get_react_validators(self) -> List: - """返回 ReAct 模式下使用的验证器列表""" - - # 获取用户要求的语言(用于语言验证) - 
required_lang = getattr(getattr(self.state, "request", None), "language", "en").lower() - - def _extract_svg_fragment(svg_code: str) -> str: - """提取干净的 SVG 片段""" - if not svg_code: - return "" - text = svg_code.strip() - if text.startswith("```"): - lines = [line for line in text.splitlines() if line.strip("`").strip()] - for i, line in enumerate(lines): - if "<svg" in line or "<SVG" in line: - text = "\n".join(lines[i:]) - break - start = text.find("<svg") - end = text.rfind("</svg>") - if start == -1 or end == -1: - return text - end += len("</svg>") - return text[start:end].strip() - - def _inject_chinese_font_for_validation(svg_code: str) -> str: - """为验证注入中文字体(与 workflow 中的逻辑一致)""" - import re - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - if not has_chinese: - return svg_code - - # 中文友好字体列表 - chinese_fonts = 'Noto Sans CJK SC, Microsoft YaHei, SimHei, SimSun, WenQuanYi Zen Hei, sans-serif' - - # 替换所有 font-family 属性 - svg_code = re.sub( - r'font-family="[^"]*"', - f'font-family="{chinese_fonts}"', - svg_code - ) - - # 注入全局样式 - idx = svg_code.find(">") - if idx != -1: - style_block = f""" - <style type="text/css"> - text, tspan {{ - font-family: {chinese_fonts} !important; - }} - </style> -""" - svg_code = svg_code[:idx + 1] + style_block + svg_code[idx + 1:] - - return svg_code - - def validate_svg_structure(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 基本结构""" - import xml.etree.ElementTree as ET - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return False, "缺少 svg_code 字段或内容为空" - - fragment = _extract_svg_fragment(svg_code) - if "<svg" not in fragment.lower(): - return False, "未检测到 <svg> 根标签" - - if "viewBox" not in fragment and "viewbox" not in fragment: - return False, "缺少 viewBox 属性,请添加 viewBox" - - try: - ET.fromstring(fragment) - except Exception as e: - return False, f"SVG XML 解析失败: {e}" - - return True, "" - - def validate_chinese_font(content: str, parsed_result: 
Dict[str, Any]) -> Tuple[bool, str]: - """验证中文字体设置""" - import re - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - if not has_chinese: - return True, "" # 没有中文,跳过验证 - - # 检查是否使用了不支持中文的字体 - bad_fonts = ["Arial", "Helvetica", "Times", "Courier"] - for font in bad_fonts: - if f'font-family="{font}"' in svg_code or f"font-family='{font}'" in svg_code: - return False, ( - f"检测到中文文本但使用了不支持中文的字体 {font}。" - f"请使用中文友好字体,如:'Noto Sans CJK SC', 'Microsoft YaHei', 'SimHei', 'SimSun', sans-serif" - ) - - return True, "" - - def validate_colors_applied(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证配色是否正确应用""" - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return False, "缺少 svg_code 字段" - - # 检查是否还有黑白灰色(应该被替换为配色) - if 'fill="black"' in svg_code or 'fill="#000"' in svg_code or 'fill="#000000"' in svg_code: - return False, "检测到黑色填充,请使用配色方案中的颜色替换所有黑色" - - return True, "" - - def validate_language_requirement(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 SVG 文本语言是否符合用户要求""" - import re - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - - # 检查是否包含中文字符 - has_chinese = bool(re.search(r'[\u4e00-\u9fff]', svg_code)) - - # 如果用户要求英文,但 SVG 包含中文,则验证失败 - if required_lang in ["en", "english"] and has_chinese: - return False, ( - f"用户要求使用英文(language={required_lang}),但生成的 SVG 包含中文字符。" - f"请严格按照用户要求的语言生成 SVG,所有文本标签必须使用英文。" - f"例如:'Data Processing', 'Model Training', 'Feature Extraction' 等。" - ) - - # 如果用户要求中文,但 SVG 不包含中文,给出提示(警告而非错误) - if required_lang in ["zh", "chinese", "cn", "中文"] and not has_chinese: - # 这里不返回失败,因为可能是纯图形的 SVG - # 但可以在日志中记录 - log.warning(f"用户要求使用中文(language={required_lang}),但生成的 SVG 不包含中文字符") - - return True, "" - - def validate_svg_renderable(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """验证 
SVG 是否可以成功渲染""" - import tempfile - import os - - svg_code = parsed_result.get("svg_code", "") - if not svg_code: - return True, "" # 如果没有 SVG,跳过此验证 - - # 提取 SVG 片段 - fragment = _extract_svg_fragment(svg_code) - - # 注入中文字体(模拟 workflow 的行为) - fragment_with_font = _inject_chinese_font_for_validation(fragment) - - # 尝试渲染 SVG - try: - from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_svg_render - - # 创建临时文件用于渲染测试 - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file: - tmp_path = tmp_file.name - - try: - local_tool_for_svg_render({ - "svg_code": fragment_with_font, - "output_path": tmp_path, - }) - # 渲染成功,删除临时文件 - if os.path.exists(tmp_path): - os.remove(tmp_path) - return True, "" - except Exception as e: - # 渲染失败,删除临时文件 - if os.path.exists(tmp_path): - os.remove(tmp_path) - return False, f"SVG 渲染失败: {str(e)}。请检查 SVG 代码的格式和结构,确保所有属性值都被正确引号包裹。" - except Exception as e: - return False, f"渲染验证过程出错: {str(e)}" - - return [validate_svg_structure, validate_chinese_font, validate_colors_applied, validate_language_requirement, validate_svg_renderable] - - def update_state_result( - self, - state: Paper2FigureState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - svg_code = None - if isinstance(result, dict): - svg_code = result.get("svg_code") - state.figure_tec_svg_color_content = svg_code or "" - super().update_state_result(state, result, pre_tool_results) - - -def create_technical_route_colorize_svg_agent( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - parser_type: str = "json", - **kwargs, -) -> TechnicalRouteColorizeSvgAgent: - if tool_manager is None: - from dataflow_agent.toolkits.tool_manager import get_tool_manager - tool_manager = get_tool_manager() - - return TechnicalRouteColorizeSvgAgent( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - 
parser_type=parser_type, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/technical_route_desc_generator_agent.py b/dataflow_agent/agentroles/paper2any_agents/technical_route_desc_generator_agent.py deleted file mode 100644 index 44769f7..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/technical_route_desc_generator_agent.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -TechnicalRouteDescGenerator agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-08 00:44:59 -生成位置: dataflow_agent/agentroles/common_agents/technical_route_desc_generator_agent.py - -本文件由 `dfa create --agent_name technical_route_desc_generator` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional, List, Tuple - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("technical_route_desc_generator") -class TechnicalRouteDescGenerator(BaseAgent): - """TODO: 描述 technical_route_desc_generator 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "technical_route_desc_generator" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_technical_route_desc_generator" - - @property - def task_prompt_template_name(self) -> str: 
- # TODO: 修改为真实的模板 id - return "task_prompt_for_technical_route_desc_generator" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return { - "paper_idea": pre_tool_results.get("paper_idea", ""), - "style": pre_tool_results.get("style", ""), - "lang": self.state.request.language - } - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- ReAct 验证器 ---------- - def get_react_validators(self) -> List: - """ - 返回 ReAct 模式下使用的验证器列表。 - - 验证器签名固定为: - validator(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str] - 这里基于 parsed_result["svg_code"] 做 SVG XML 合法性检查。 - """ - - def _extract_svg_fragment(svg_code: str) -> str: - """ - 从模型返回的字符串中,提取出干净的 <svg>...</svg> 片段。 - - 会做几件事: - 1. 去掉首尾空白 - 2. 去掉可能的 ``` 或 ```svg 代码块包裹 - 3. 只保留从第一个 <svg 开始到最后一个 </svg> 结束的部分 - """ - if not svg_code: - return "" - - text = svg_code.strip() - - # 处理 ```svg ... ``` 或 ``` ... 
``` 代码块 - if text.startswith("```"): - lines = [line for line in text.splitlines() if line.strip("`").strip()] - # 找到第一行包含 <svg 的 - for i, line in enumerate(lines): - if "<svg" in line or "<SVG" in line or "<Svg" in line: - text = "\n".join(lines[i:]) - break - - # 定位 <svg ...> 到 </svg> 的区间 - start = text.find("<svg") - end = text.rfind("</svg>") - if start == -1 or end == -1: - # 找不到明确片段时,直接返回原始文本,交给 XML 解析报错 - return text - - end += len("</svg>") - return text[start:end].strip() - - def validate_svg(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, str]: - """ - technical_route_desc_generator 的 SVG 验证器。 - - 使用解析后的结果 parsed_result["svg_code"] 做 XML 级合法性检查, - 与 base_agent._run_validators(content, parsed_result) 的调用约定保持一致。 - """ - import xml.etree.ElementTree as ET - - svg_code = None - if isinstance(parsed_result, dict): - svg_code = parsed_result.get("svg_code") - - if not svg_code: - return False, "缺少 svg_code 字段或内容为空" - - fragment = _extract_svg_fragment(svg_code) - if "<svg" not in fragment and "<SVG" not in fragment: - return False, "返回内容中未检测到 <svg> 根标签" - - try: - # XML 解析只关心 well‑formed,不关心命名空间等 - ET.fromstring(fragment) - except Exception as e: - log.warning(f"technical_route_desc_generator.validate_svg: SVG XML 解析失败: {e}") - return False, f"SVG 不是合法 XML: {e}" - - return True, "SVG 验证通过" - - return [validate_svg] - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写 - - 期望 LLM 返回形如: - {"svg_code": "<svg ...>...</svg>"} - """ - # 解析 LLM 返回的 JSON 结果,取出 svg_code - svg_code = None - if isinstance(result, dict): - svg_code = result.get("svg_code", None) - state.figure_tec_svg_content = svg_code - super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- 
-async def technical_route_desc_generator( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """technical_route_desc_generator 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = TechnicalRouteDescGenerator( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_technical_route_desc_generator( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> TechnicalRouteDescGenerator: - return TechnicalRouteDescGenerator.create( - 
tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/agentroles/paper2any_agents/topic_writer_agent.py b/dataflow_agent/agentroles/paper2any_agents/topic_writer_agent.py deleted file mode 100644 index c0e8341..0000000 --- a/dataflow_agent/agentroles/paper2any_agents/topic_writer_agent.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -TopicWriter agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-27 02:40:27 -生成位置: dataflow_agent/agentroles/common_agents/topic_writer_agent.py - -本文件由 `dfa create --agent_name topic_writer` 自动生成。 -1. 填写 prompt-template 名称 -2. 根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("topic_writer") -class TopicWriter(BaseAgent): - """TODO: 描述 topic_writer 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "topic_writer" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return 
"system_prompt_for_topic_writer" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_topic_writer" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数""" - target_pages = getattr(self.state, "target_pages", 60) - - # 动态计算目标字符数 - is_en = self._is_english_text(self.state.current_text) if self.state.current_text else False - chars_per_page = 3000 if is_en else 800 - target_chars = target_pages * chars_per_page - - return { - 'text_content': self.state.text_content, - 'generation_round': pre_tool_results.get('generation_round', 0), - 'language': getattr(self.state.request, "language", "中文"), - 'target_pages': target_pages, - 'target_chars': target_chars, - } - - def _is_english_text(self, text: str) -> bool: - """判断文本是否为英文""" - if not text: - return False - sample = text[:5000] - ascii_count = sum(1 for c in sample if ord(c) < 128) - return (ascii_count / len(sample)) > 0.8 - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将生成的文本写回 state.text_content""" - state.text_content = result['text'] - super().update_state_result(state, result, pre_tool_results) -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def topic_writer( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, 
Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """topic_writer 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = TopicWriter( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_topic_writer( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> TopicWriter: - return TopicWriter.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/cli.py b/dataflow_agent/cli.py deleted file mode 100644 index 9c3b681..0000000 --- a/dataflow_agent/cli.py +++ /dev/null @@ -1,187 +0,0 @@ -import click -from pathlib import Path -from 
datetime import datetime -from jinja2 import Template -from typing import Optional -from dataflow_agent.logger import get_logger -log = get_logger(__name__) -TEMPLATE_DIR = Path(__file__).parent / "templates" - -# ---------- util ---------- -def to_snake(s: str) -> str: - import re - s = re.sub(r'[\- ]+', '_', s).strip('_') - parts = re.split(r'[_]', s) - return '_'.join(p.lower() for p in parts if p) - -def to_camel(s: str) -> str: - return ''.join(p.capitalize() for p in to_snake(s).split('_')) - -# ---------- CLI ---------- -@click.group() -def cli(): - """DataFlow-Agent command line.""" - pass - - -@cli.command("create") -@click.option("--wf_name", help="要创建的 workflow 名称") -@click.option("--agent_name", help="要创建的 agent 名称") -@click.option("--gradio_name", help="要创建的 gradio page 名称") -@click.option("--prompt_name", help="要创建的 prompt template 名称") -@click.option("--agent_as_tool_name", help="要创建的 agent-as-tool 名称") -@click.option("--state_name", help="要创建的 state 名称") -def create_artifact(wf_name: Optional[str] = None, - agent_name: Optional[str] = None, - gradio_name: Optional[str] = None, - prompt_name: Optional[str] = None, - agent_as_tool_name: Optional[str] = None, - state_name: Optional[str] = None): - """ - dfa create --wf_name xxx - dfa create --agent_name yyy - dfa create --gradio_name zzz - dfa create --prompt_name zzz - dfa create --agent_as_tool_name aaa - dfa create --state_name bbb - """ - opts = [bool(wf_name), bool(agent_name), bool(gradio_name), bool(prompt_name), bool(agent_as_tool_name), bool(state_name)] - if sum(opts) != 1: - click.echo(" --wf_name / --agent_name / --gradio_name / --prompt_name / --agent_as_tool_name / --state_name 必须且只能选一个", err=True) - raise SystemExit(1) - - timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # ------------------------------------------------------------------ - # 1. 
Workflow - # ------------------------------------------------------------------ - if wf_name: - wf_name_snake = to_snake(wf_name) - - # 1.1 workflow 源码 - wf_dest = Path(__file__).parent / "workflow" / f"wf_{wf_name_snake}.py" - wf_tpl_path = TEMPLATE_DIR / "workflow.py.jinja" - wf_context = dict( - wf_name=wf_name, - wf_name_snake=wf_name_snake, - entry="step1", - timestamp=timestamp, - ) - _generate_file(wf_dest, wf_tpl_path, wf_context, "workflow") - - # 1.2 对应测试 - project_root = Path(__file__).parent.parent - test_dest = project_root / "tests" / f"test_{wf_name_snake}.py" - test_tpl = TEMPLATE_DIR / "test_workflow.py.jinja" - test_ctx = dict( - wf_name=wf_name, - wf_name_snake=wf_name_snake, - timestamp=timestamp, - ) - _generate_file(test_dest, test_tpl, test_ctx, "test") - - # ------------------------------------------------------------------ - # 2. Agent - # ------------------------------------------------------------------ - elif agent_name: - agent_snake = to_snake(agent_name) - dest = Path(__file__).parent / "agentroles" / "common_agents" / f"{agent_snake}_agent.py" - tpl_path = TEMPLATE_DIR / "agent.py.jinja" - ctx = dict( - agent_name=agent_name, - agent_name_snake=agent_snake, - agent_name_camel=to_camel(agent_name), - timestamp=timestamp, - ) - _generate_file(dest, tpl_path, ctx, "agent") - - # ------------------------------------------------------------------ - # 3. Gradio Page - # ------------------------------------------------------------------ - elif gradio_name: - page_snake = to_snake(gradio_name) - project_root = Path(__file__).parent.parent - dest = project_root / "gradio_app" / "pages" / f"page_{page_snake}.py" - tpl_path = TEMPLATE_DIR / "gradio_page.py.jinja" - ctx = dict( - page_name=gradio_name, - page_name_snake=page_snake, - timestamp=timestamp, - ) - _generate_file(dest, tpl_path, ctx, "gradio page") - - # ------------------------------------------------------------------ - # 4. 
Prompt Template - # ------------------------------------------------------------------ - elif prompt_name: - prompt_snake = to_snake(prompt_name) - dest = Path(__file__).parent / "promptstemplates" / "resources" / f"pt_{prompt_snake}_repo.py" - tpl_path = TEMPLATE_DIR / "prompt_repo.py.jinja" - ctx = dict( - prompt_name=prompt_name, - prompt_name_snake=prompt_snake, - prompt_name_camel=to_camel(prompt_name), - timestamp=timestamp, - ) - _generate_file(dest, tpl_path, ctx, "prompt template") - - # ------------------------------------------------------------------ - # 5. Agent-as-Tool - # ------------------------------------------------------------------ - elif agent_as_tool_name: - agent_snake = to_snake(agent_as_tool_name) - dest = Path(__file__).parent / "agentroles" / "common_agents" / f"{agent_snake}_agent.py" - tpl_path = TEMPLATE_DIR / "agent_as_tool_name.py.jinja" - ctx = dict( - agent_name=agent_as_tool_name, - agent_name_snake=agent_snake, - agent_name_camel=to_camel(agent_as_tool_name), - timestamp=timestamp, - ) - _generate_file(dest, tpl_path, ctx, "agent-as-tool") - - # ------------------------------------------------------------------ - # 6. 
State - # ------------------------------------------------------------------ - else: - state_snake = to_snake(state_name) - dest = Path(__file__).parent / "states" / f"{state_snake}_state.py" - tpl_path = TEMPLATE_DIR / "state_name.py.jinja" - ctx = dict( - state_name=state_name, - state_name_snake=state_snake, - state_name_camel=to_camel(state_name), - timestamp=timestamp, - ) - _generate_file(dest, tpl_path, ctx, "state") - - -# ---------- helper ---------- -def _generate_file(dest: Path, tpl_path: Path, context: dict, file_type: str): - """ - 通用文件生成函数 - """ - dest.parent.mkdir(parents=True, exist_ok=True) - - if dest.exists(): - # click.echo(f" {dest} 已存在,跳过生成") - log.error(f" {dest} 已存在,跳过生成") - return - - if not tpl_path.exists(): - log.error(f" 模板不存在: {tpl_path}", err=True) - raise SystemExit(1) - - rendered = Template(tpl_path.read_text(encoding="utf-8")).render(**context) - dest.write_text(rendered, encoding="utf-8") - - try: - rel_path = dest.relative_to(Path.cwd()) - except ValueError: - rel_path = dest - - log.critical(f" 已生成 {file_type}: {rel_path}") - - -if __name__ == "__main__": - cli() diff --git a/dataflow_agent/promptstemplates/resources/__init__.py b/dataflow_agent/promptstemplates/resources/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/promptstemplates/resources/pt_bbox_agent_repo.py b/dataflow_agent/promptstemplates/resources/pt_bbox_agent_repo.py deleted file mode 100644 index 2698d21..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_bbox_agent_repo.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Prompt Templates for bbox_agent -Generated at: 2026-01-12 19:07:36 -""" - -# --------------------------------------------------------------------------- # -# 1. 
BboxAgent - bbox_agent 相关提示词 -# --------------------------------------------------------------------------- # -class BboxAgent: - """ - bbox_agent 任务的提示词模板 - """ - - system_prompt_for_image_text_bbox_agent = """ -你是一个强大的多模态视觉理解 AI 助手。 -你的任务是分析图像,提取其中所有的文本内容及其精确的边界框(Bounding Box)。 -""" - - task_prompt_for_image_text_bbox_agent = """ -请执行高精文字检测与识别任务: -1. 提取图像中所有的文字内容。 -2. 为每一行文字提供精确的边界框坐标(location)。 -3. rotate_rect坐标!!! -4. 不要任何```json包裹!!,直接返回文本格式json字符串! -JSON 结构如下: - -[ - { - "rotate_rect": [500, 48, 63, 791, 90], - "text": "Cartoon-style Mechanistic Overview of T Cell Generation" - } - xxx -] - -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_kb_prompt_agent_repo.py b/dataflow_agent/promptstemplates/resources/pt_kb_prompt_agent_repo.py deleted file mode 100644 index 7199031..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_kb_prompt_agent_repo.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Prompt Templates for kb_prompt_agent -""" - -system_prompt_for_kb_prompt_agent = """ -You are a helpful assistant. -Follow the user's instruction carefully and provide a concise, high-quality response. -""" - -task_prompt_for_kb_prompt_agent = """ -{prompt} -""" - diff --git a/dataflow_agent/promptstemplates/resources/pt_long_paper_repo.py b/dataflow_agent/promptstemplates/resources/pt_long_paper_repo.py deleted file mode 100644 index d3a60c6..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_long_paper_repo.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -Prompt Templates for long_paper -Generated at: 2025-12-27 02:15:52 -""" - -# --------------------------------------------------------------------------- # -# 1. 
LongPaperOutlineAgent - 提示词 -# --------------------------------------------------------------------------- # -class LongPaperOutlineAgent: - """ - long_paper_outline_agent 任务的提示词模板 - """ - - # 系统提示词:与普通 outline_agent 共享或专用 - system_prompt_for_outline_agent = """ -你是一位拥有丰富学术汇报经验的PPT设计专家及大纲生成助手。你的核心任务是将一篇学术论文(或长文档的一部分)转化为一份逻辑清晰、视觉布局合理的PPT演示大纲。 - -请遵循以下严格规则: -1. **深度理解**:仔细阅读用户提供的文本内容,提取核心论点、实验数据和结论。 -2. **视觉导向**:在规划每一页PPT时,不仅要生成文字内容,必须明确指出该页是否需要展示特定的插图(Images)或表格(Tables)。 -3. **布局建议**:为每一页提供具体的布局指导(例如:左文右图、上标题下表格、两栏对比等)。 -4. **格式严格**:输出必须且只能是标准的 JSON 格式数组。严禁包含 markdown 标记(如 ```json)、前言、后语或任何非 JSON 字符。 -""" - - # 1. 首页 Prompt (Is First Batch) - task_prompt_for_long_paper_outline_agent_first = """ -这是长文档分批生成 PPT 的**第一批次**。 -当前进度:第 {batch_index} 批 / 共 {total_batches} 批。 -本批次目标页数:{pages_to_generate} 页(包括封面)。 -总目标页数:{page_count} 页。 - -**输入数据(当前文本片段):** -{current_chunk} - -**任务要求:** -1. **第一页必须是封面**:包含 PPT 主题(Title)和汇报人信息(Presenter)。不需要额外的内容。 -2. 后续页面开始进入正文介绍(如背景、引言、核心问题等)。 -3.输出内容的语言为 **{language}**。 -4. 不需要致谢页(除非文本很短,这是唯一一批)。 - -**输出格式要求(JSON Array):** -请返回一个 JSON 数组,数组中每个对象代表一页PPT,结构如下: -- `title`: 该页PPT的标题。 -- `layout_description`: 详细的版面布局描述。 -- `key_points`: 一个包含多个关键要点的字符串列表(List<String>)。 -- `asset_ref`: 如果该页需要展示论文中的原图或表格,请提名或路径取其文件(例如 "Table_2", "images/architecture.png"),并且只能1 个 asset;如果不需要引用原图,请填 null。 - -示例结构: -[ - {{ - "title": "大语言模型的幻觉问题研究", - "layout_description": "封面设计,居中放置大号标题,下方为汇报人姓名。", - "key_points": ["汇报人:DataFlow Agent"], - "asset_ref": null - }}, - {{ - "title": "研究背景", - "layout_description": "左侧文字介绍,右侧配图。", - "key_points": ["大模型幻觉的定义。", "当前面临的挑战。"], - "asset_ref": "images/intro.png" - }} -] -""" - - # 2. 中间页 Prompt (Middle Batch) - task_prompt_for_long_paper_outline_agent_middle = """ -这是长文档分批生成 PPT 的**中间批次**。 -当前进度:第 {batch_index} 批 / 共 {total_batches} 批。 -本批次目标页数:{pages_to_generate} 页。 - -**输入数据(当前文本片段):** -{current_chunk} - -**任务要求:** -1. **直接生成正文内容**:不需要封面,也不要致谢。 -2. 承接上一批次的内容,继续展开当前的章节。 -3. 如果文本包含新的章节标题,请作为新的一页或新章节的开始。 -4. 
输出内容的语言为 **{language}**。 - -**输出格式要求(JSON Array):** -JSON 数组,每个对象代表一页PPT。 -结构字段:`title`, `layout_description`, `key_points`, `asset_ref`。 -- `asset_ref`: 如果该页需要展示论文中的原图或表格,请提名或路径取其文件(例如 "Table_2", "images/architecture.png"),并且只能1 个 asset;如果不需要引用原图,请填 null。 - -示例结构: -[ - { - "title": "Methodology: Overview", - "layout_description": "Top-down layout: brief textual overview at the top, followed by a large pipeline diagram showing the main components of the approach.", - "key_points": ["Provide a high-level description of the proposed method.", "List the key components or stages of the approach."], - "asset_ref": "images/method_pipeline.png" - }, - { - "title": "Experimental Setup", - "layout_description": "Two-column layout: left for text describing datasets, right for a simple table.", - "key_points": ["Mention the main datasets or benchmarks used.", "Briefly describe the experimental environment."], - "asset_ref": "Table_1" - } -] -""" - - # 3. 尾页 Prompt (Is Last Batch) - task_prompt_for_long_paper_outline_agent_last = """ -这是长文档分批生成 PPT 的**最后一批次**。 -当前进度:第 {batch_index} 批 / 共 {total_batches} 批。 -本批次目标页数:{pages_to_generate} 页(包括致谢)。 - -**输入数据(当前文本片段):** -{current_chunk} - -**任务要求:** -1. 生成剩余的正文内容(结论、未来展望等)。 -2. 
**最后一页必须是致谢(Thank You)**:简短的结束语。 -3.输出内容的语言为 **{language}**。 - -**输出格式要求(JSON Array):** -JSON 数组,每个对象代表一页PPT。 -结构字段:`title`, `layout_description`, `key_points`, `asset_ref`。 -- `asset_ref`: 如果该页需要展示论文中的原图或表格,请提名或路径取其文件(例如 "Table_2", "images/architecture.png"),并且只能1 个 asset;如果不需要引用原图,请填 null。 - -确保最后一页是致谢页。 - -示例结构: -[ - { - "title": "Conclusion", - "layout_description": "Two-column layout: left column bullet points, right column illustrative figure.", - "key_points": ["Summarize the main contributions.", "Highlight the effectiveness of the proposed method."], - "asset_ref": "images/conclusion_chart.png" - }, - { - "title": "Thank You", - "layout_description": "Title centered; minimal content with short closing remark and optional contact/info line.", - "key_points": ["Thank you for your attention.", "Q&A"], - "asset_ref": null - } -] -""" - - -# --------------------------------------------------------------------------- # -# 2. ContentExpander - 提示词 -# --------------------------------------------------------------------------- # -class ContentExpander: - """ - content_expander 任务的提示词模板 - """ - - system_prompt_for_content_expander = """ -你是一个专业的学术写作助手和内容扩写专家。你的任务是将输入的简短文本或草稿,扩写成篇幅更长、细节更丰富、逻辑更严密的文章或报告。 -你的扩写应保持专业性,增加必要的背景介绍、详细的解释、具体的例子或论证,以满足生成长篇 PPT 的内容需求。 -""" - - task_prompt_for_content_expander = """ -**当前任务:** -对以下文本进行第 {expansion_round} 轮扩写。 - -**输入文本:** -{text_content} - -**扩写要求:** -1. **大幅增加篇幅**:在保持原意的前提下,通过增加细节、举例、背景分析、优缺点对比等方式,显著增加字数。 -2. **结构完整**:如果输入是片段,请将其补全为完整的章节;如果输入是提纲,请将其展开为全文。 -3. **保持连贯**:确保扩写后的内容逻辑通顺,段落过渡自然。 -4. **输出限制**:直接输出扩写后的完整文本,不要包含任何类似于“好的,这是扩写后的内容”的废话。不要使用 Markdown 代码块包裹。 -5. 如果需要表格,必须输出md表格内容,Table_1, xxx -请开始扩写: - -""" - - -# --------------------------------------------------------------------------- # -# 3. 
TopicWriter - 提示词 -# --------------------------------------------------------------------------- # -class TopicWriter: - """ - topic_writer 任务的提示词模板 - 用于根据 Topic 生成长篇研究报告 - """ - - system_prompt_for_topic_writer = """ -你是一位资深的学术研究员和技术写作专家。你的任务是根据给定的主题(Topic),撰写一份详细、专业、结构完整的研究报告或技术文档。 - -你的写作应该: -1. 内容丰富、逻辑严密、论证充分 -2. 包含必要的背景介绍、核心概念、方法论、应用场景等 -3. 适合用于生成长篇 PPT 演示文稿 -""" - - task_prompt_for_topic_writer = """ -**任务:** 根据以下主题生成详细的研究报告(第 {generation_round} 轮生成) - -**主题:** -{text_content} - -**生成要求:** -1. **语言**:使用 {language} 语言撰写 -2. **篇幅**:大幅扩展内容,目标字数应达到支持 {target_pages} 页 PPT 的长度 - - 目标字符数:约 {target_chars} 字符 -3. **结构**: - - 包含完整的引言、背景、核心内容、结论等章节 - - 每个章节都要详细展开,提供具体的例子、数据、分析 -4. **内容深度**: - - 如果是第一轮生成,从主题出发构建完整框架 - - 如果是后续轮次,在现有内容基础上继续扩展和深化 - - 不要重复已有内容,而是增加新的维度和细节 - -**输出格式:** -1.直接输出完整的研究报告文本,不要包含任何说明性文字或 Markdown 代码块标记. -2.如果需要表格,可以输出md表格内容 -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_operator_qa_repo.py b/dataflow_agent/promptstemplates/resources/pt_operator_qa_repo.py deleted file mode 100644 index 366f1f5..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_operator_qa_repo.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Prompt Templates for OperatorQA Agent -Generated at: 2025-12-01 15:05:13 - -本文件定义了算子问答 Agent 的提示词模板。 -""" - -# --------------------------------------------------------------------------- # -# OperatorQA - 算子问答相关提示词 -# --------------------------------------------------------------------------- # -class OperatorQAPrompts: - """ - 算子问答 Agent 的提示词模板 - - 用于支持: - 1. 自然语言查询算子功能 - 2. 查询特定算子做什么 - 3. 查询算子参数含义 - 4. 查看算子源码 - 5. 多轮对话 - """ - - system_prompt_for_operator_qa = """ -[角色] -你是 DataFlow 算子库的智能问答助手。你的职责是帮助用户了解和使用 DataFlow 中的各种数据处理算子。 - -[能力] -1. 根据用户描述的需求,推荐合适的算子 -2. 解释算子的功能、用途和使用场景 -3. 详细说明算子的参数含义和配置方法 -4. 在需要时展示算子的源码实现 -5. 
基于多轮对话理解用户的上下文需求 - -[DataFlow 算子简介] -DataFlow 是一个数据处理框架,提供了丰富的算子用于数据清洗、过滤、生成、评估等任务。 -每个算子都是一个 Python 类,通常包含: -- `__init__` 方法:初始化算子,配置必要的参数(如 LLM 服务、提示词等) -- `run` 方法:执行数据处理逻辑,接收输入数据并产出处理结果 - -[可用工具] -你可以调用以下工具来获取算子信息: - -**算子相关工具:** -1. **search_operators(query, top_k)** - 根据功能描述搜索相关算子 - - 当用户询问某类功能的算子时使用 - - 如果对话历史中已有相关算子信息,可以不调用直接回答 - -2. **get_operator_info(operator_name)** - 获取指定算子的详细描述 - - 当用户询问特定算子的功能时使用 - -3. **get_operator_source_code(operator_name)** - 获取算子的完整源代码 - - 当用户需要了解算子实现细节时使用 - -4. **get_operator_parameters(operator_name)** - 获取算子的参数详情 - - 当用户询问算子如何配置、参数含义时使用 - - -**文件操作工具:** -- `read_text_file(file_path, start_line, end_line)`: 读取项目内的文本文件内容。可指定行范围。 -- `list_directory(dir_path, show_hidden, recursive)`: 查看项目目录结构。 - - -[工具调用策略] -- 如果是新问题且对话历史中没有相关信息 → 调用 search_operators 检索 -- 如果对话历史中已有相关算子信息 → 可以直接回答,无需重复检索 -- 如果用户追问某个算子的细节 → 调用 get_operator_info/get_operator_source_code/get_operator_parameters -- 如果用户追问代码实现、开发、部署等需要阅读源代码的问题,或询问整体架构,你可以通过多轮调用文件工具来查询信息 - - -[回答风格] -1. 清晰简洁,重点突出 -2. 使用中文回答(除非用户要求英文) -3. 对于技术细节,提供具体的代码示例 -4. 在解释参数时,说明参数类型、默认值和作用 - -[输出格式] -请以 JSON 格式返回,包含以下字段: -{{ - "answer": "对用户问题的详细回答", - "related_operators": ["相关算子名称列表"], - "source_explanation": "说明答案的信息来源,例如:'通过search_operators检索到的XXX算子'、'基于对话历史中的算子信息'、'基于我的知识库'", - "code_snippet": "如有必要,提供代码片段(可选)", - "follow_up_suggestions": ["可能的后续问题建议(可选)"] -}} -""" - - task_prompt_for_operator_qa = """ -[用户问题] -{user_query} - -[任务] -请根据用户问题回答。对话历史会自动包含在消息中,你可以参考之前的对话。 - -工具调用指南: -1. 如果需要查找算子,调用 search_operators 工具 -2. 如果需要某个算子的详细信息,调用 get_operator_info 工具 -3. 如果需要源码,调用 get_operator_source_code 工具 -4. 如果需要参数详情,调用 get_operator_parameters 工具 -5. 
如果之前的对话中已有相关信息,可以直接回答,无需重复调用工具 - -回答要求: -- 基于工具返回的信息或对话上下文中的信息回答 -- 在 source_explanation 中说明答案来源 -- 如果问题不明确,可以在 follow_up_suggestions 中给出澄清建议 - -请以 JSON 格式返回你的回答。 -""" - - # 用于获取源码的追问提示词 - task_prompt_for_get_source = """ -[用户请求] -用户希望查看算子 "{operator_name}" 的源码。 - -[算子源码] -```python -{source_code} -``` - -[任务] -请简要说明这个算子的实现逻辑,并在 code_snippet 字段中返回完整源码。 - -请以 JSON 格式返回: -{{ - "answer": "对算子实现的简要说明", - "related_operators": ["{operator_name}"], - "code_snippet": "完整源码" -}} -""" - - # 用于解释参数的提示词 - task_prompt_for_explain_params = """ -[用户请求] -用户希望了解算子 "{operator_name}" 的参数详情。 - -[参数信息] -__init__ 参数: -{init_params} - -run 方法参数: -{run_params} - -[任务] -请详细解释每个参数的含义、类型、默认值和使用场景。 - -请以 JSON 格式返回: -{{ - "answer": "参数的详细说明", - "related_operators": ["{operator_name}"], - "code_snippet": "使用示例代码(如有必要)" -}} -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_target_parse_repo.py b/dataflow_agent/promptstemplates/resources/pt_target_parse_repo.py deleted file mode 100644 index 92fc54c..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_target_parse_repo.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -Prompt Templates for target_parse -Generated at: 2025-11-26 01:59:56 -""" - -# --------------------------------------------------------------------------- # -# 1. TargetParse - target_parse 相关提示词 -# --------------------------------------------------------------------------- # -class TargetParse: - """ - target_parse 任务的提示词模板 - """ - - task_prompt_for_example = """ -Your task description here. -Input: {input_data} -""" - - system_prompt_for_example = """ -You are an AI assistant for target_parse tasks. 
-""" \ No newline at end of file diff --git a/dataflow_agent/promptstemplates/resources/pt_tech_route_reference_analyzer.py b/dataflow_agent/promptstemplates/resources/pt_tech_route_reference_analyzer.py deleted file mode 100644 index a61ed1d..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_tech_route_reference_analyzer.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -Prompt Templates for tech_route_reference_analyzer -Generated at: 2025-01-26 -用于 VLM 分析技术路线图参考图 -""" - - -class TechRouteReferenceAnalyzer: - system_prompt_for_tech_route_reference_analyzer = """ -You are a Technical Route Diagram SVG Generator specialized in analyzing reference images. -Your task is to analyze the provided reference image and generate an SVG code that replicates its structure and style. - -OUTPUT FORMAT: -- You MUST output a strict JSON object: {"svg_code": "<svg ...>...</svg>"} -- No extra text, no Markdown, no explanations -- The SVG code must be complete and valid - -REQUIREMENTS: -1) Analyze the reference image's layout, node shapes, arrow styles, colors, and text positioning -2) Generate SVG code that closely matches the reference image's visual structure -3) SVG must include viewBox; width/height should be "100%" -4) Preserve the overall layout direction (horizontal/vertical/mixed) -5) Match node shapes, sizes, and spacing as closely as possible -6) Replicate arrow/connection styles (straight/curved, thickness, markers) -7) Use similar color schemes if the image is colored; use grayscale if it's black/white -""" - - task_prompt_for_tech_route_reference_analyzer = """ -Please analyze this reference technical route diagram image and generate SVG code that replicates its structure and style. - -ANALYSIS POINTS: -1. **Overall Layout**: Is it horizontal, vertical, or mixed? How many main stages? -2. **Node Styles**: What shapes (rect, rounded rect, circle, diamond)? Border thickness and colors? -3. **Arrow Styles**: Straight or curved lines? Arrow shapes and sizes? Line thickness? 
-4. **Color Scheme**: What colors are used? How do different stages/types differ in color? -5. **Text Layout**: Is text inside or outside nodes? Font sizes and colors? - -Based on your analysis, generate an SVG code that closely matches the reference image's visual structure. - -**LANGUAGE REQUIREMENT: {lang}** -- If {lang} is "en" or "EN": Use English for all text labels -- If {lang} is "zh" or "ZH": Use Chinese for all text labels - -Output only JSON {{"svg_code": "..."}}. -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_technical_route_desc_generator_repo.py b/dataflow_agent/promptstemplates/resources/pt_technical_route_desc_generator_repo.py deleted file mode 100644 index 105289d..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_technical_route_desc_generator_repo.py +++ /dev/null @@ -1,507 +0,0 @@ -""" -Prompt Templates for technical_route_desc_generator -Generated at: 2025-12-08 01:19:09 -""" - -# --------------------------------------------------------------------------- # -# 1. TechnicalRouteDescGenerator - technical_route_desc_generator 相关提示词 -# --------------------------------------------------------------------------- # -class TechnicalRouteDescGenerator: - system_prompt_for_figure_desc_generator_free = """ -你是一位世界顶级的 CVPR/NeurIPS 视觉架构师 -你的核心能力是将晦涩难懂的论文逻辑,转化为**具体的、画面感极强的视觉描述。 -""" - - task_prompt_for_figure_desc_generator_free =""" -下面是一篇论文的核心研究内容(paper_idea): - -{paper_idea} - -请根据上述内容,编写一个用于 Text-to-Image 模型的英文提示词(Prompt)。 - -### 提示词编写策略: -1. 强调 “科研绘图”,用于论文插图; -2. **风格(Style)**:{style} - - 必须强制包含论文内容的关键词. -3. 白色背景,然后要分成多个panel,就跟论文中的图一样,每个panel都要有自己的标题,标题要放在panel的上方; -4. 信息量要丰富,填满整个画面; - -### 最终生成的 fig_desc 必须是一段连贯的英文描述; - -# Output Format (The Golden Schema) -请严格遵守以下 JSON 输出要求: - -1. 最终响应必须是一个严格合法的 JSON 对象,不能包含任何额外文字、解释或 Markdown 标记。 -2. 该 JSON 对象只能包含一个键:fig_desc。 -3. fig_desc 的值必须是一个字符串,用于描述整张图的视觉结构和内容。 -4. 
在 JSON 中: - - 所有双引号必须写成 \\"; - - 所有换行必须写成 \\n(不能直接换行输出); - - 不要包含制表符或其它控制字符。 - -示例(仅示意结构,实际内容请根据论文生成): -{ - "fig_desc": "xxx" -} - - -""" - - - system_prompt_for_figure_desc_generator_mid = """ -# Role -你是一位 CVPR/NeurIPS 顶刊的视觉架构师。你的核心能力是将抽象的论文逻辑转化为具体的、结构化的、可直接用于绘图模型的视觉指令。 - -# Objective -阅读我提供的论文内容,输出一份 [VISUAL SCHEMA]。这份 Schema 将被直接发送给 AI 绘图模型,因此必须使用清晰的物理描述。 - -# Phase 1: Layout Strategy Selector (关键步骤:布局决策) -在生成 Schema 之前,请先分析论文逻辑,从以下布局原型中选择最合适的一个(或组合): -1. Linear Pipeline: 左→右流向 (适合 Data Processing, Encoding-Decoding)。 -2. Cyclic/Iterative: 中心包含循环箭头 (适合 Optimization, RL, Feedback Loops)。 -3. Hierarchical Stack: 上→下或下→上堆叠 (适合 Multiscale features, Tree structures)。 -4. Parallel/Dual-Stream: 上下平行的双流结构 (适合 Multi-modal fusion, Contrastive Learning)。 -5. Central Hub: 一个核心模块连接四周组件 (适合 Agent-Environment, Knowledge Graphs)。 - -# Phase 2: Schema Generation Rules -1. Dynamic Zoning: 根据选择的布局,定义 2-5 个物理区域 (Zones)。 -2. Internal Visualization: 必须定义每个区域内部的“物体”(图标、网格、树等),禁止仅使用抽象概念。 -3. Explicit Connections: 如果是循环过程,必须明确描述 "Curved arrow looping back from Zone X to Zone Y" 之类的连接。 - -# Output Format (The Golden Schema) -请严格遵守以下 JSON 输出要求: - -1. 最终响应必须是一个严格合法的 JSON 对象,不能包含任何额外文字、解释或 Markdown 标记。 -2. 该 JSON 对象只能包含一个键:fig_desc。 -3. fig_desc 的值必须是一个字符串,用于描述整张图的视觉结构和内容。 -4. 在 JSON 中: - - 所有双引号必须写成 \\"; - - 所有换行必须写成 \\n(不能直接换行输出); - - 不要包含制表符或其它控制字符。 - -示例(仅示意结构,实际内容请根据论文生成): -{ - "fig_desc": "[Style & Meta-Instructions] ... \\n[LAYOUT CONFIGURATION] ... \\n[ZONE 1: LOCATION - ...] ... \\n[CONNECTIONS] ..." -} - -在 fig_desc 字符串中,建议按照如下区块依次描述: -[Style & Meta-Instructions] -[LAYOUT CONFIGURATION] -[ZONE 1: LOCATION - LABEL] -[ZONE 2: LOCATION - LABEL] -[ZONE 3: LOCATION - LABEL] -[CONNECTIONS] - -# Input Data - -paper_idea -""" - - task_prompt_for_figure_desc_generator_mid = """ -**Style Reference & Execution Instructions:** - -1. Art Style (Visio/Illustrator Aesthetic): - Generate a professional academic architecture diagram suitable for a top-tier computer science paper (CVPR/NeurIPS). 
- - Visuals: Flat vector graphics, distinct geometric shapes, clean thin outlines, and soft pastel fills (Azure Blue, Slate Grey, Coral Orange). - - Layout: Strictly follow the spatial arrangement defined below. - - Vibe: Technical, precise, clean white background. NOT hand-drawn, NOT photorealistic, NOT 3D render, NO shadows/shading. - -2. CRITICAL TEXT CONSTRAINTS (Read Carefully): - - DO NOT render meta-labels: Do not write words like "ZONE 1", "LAYOUT CONFIGURATION", "Input", "Output", or "Container" inside the image. These are structural instructions for YOU, not text for the image. - - ONLY render "Key Text Labels": Only text inside double quotes (e.g., "[Text]") listed under "Key Text Labels" should appear in the diagram. - - Font: Use a clean, bold Sans-Serif font (like Roboto or Helvetica) for all labels. - -论文内容(paper_idea)如下: - -{paper_idea} - -要求提示词一定满足: -1. 信息丰富,信息量大; -2. 科研绘图,白色背景; -3. {style} 风格提示词; -4. 最重要,生成提示词要写入:“生成的文字都要在Icon旁边,不能覆盖Icon!!!!!” - -请基于上述论文内容和风格要求,设计对应的视觉架构指令,并按照系统提示中的 JSON 规范,仅输出一个 JSON 对象: -- 该对象只包含一个键:fig_desc; -- fig_desc 的值为完整的视觉描述字符串; -- 保证整个响应是严格合法的 JSON(双引号使用 \\" 转义,换行使用 \\n 转义),不要输出任何多余文本、注释或 Markdown 标记。 -""" - - # 用户/任务层提示:描述输入是什么 + 要求生成“复杂、美观、箭头明显”的技术路线图 SVG - - task_prompt_for_technical_route_desc_generator = """ -下面是一个论文的研究内容(paper_idea): - -{paper_idea} - -请根据该想法设计一份技术路线图,并用 SVG 代码进行表示。 - -整体要求(重要): -1. 技术路线图需要包括关键步骤/模块及其先后关系,**建议划分为 3~5 个清晰的阶段**,每个阶段内部可以包含 3~6 个节点,使整体结构**信息量丰富但有条理**。 -2. 每个步骤使用风格统一的节点形状(推荐圆角矩形),可以适度使用少量其他形状(如椭圆)突出起点/终点或关键模块,但整体视觉语言要统一。 -3. 流程连接必须使用**线条较粗、颜色对比明显的箭头**(可以是直线或略带弧度的 path),箭头头部要清晰可见,确保方向一眼可辨;允许存在分支和汇合。 -4. 布局建议采用自左向右或自上而下的多行/多列结构,可以通过阶段分组(背景块或分区标题)表现整体流程的层次感,使图看起来**结构清晰、相对复杂且完整**,但不要杂乱无章。 -5. 颜色风格要**美观、现代**:可以区分不同阶段或节点类别,适度使用渐变、阴影或圆角等效果增强观感,但要注意整体协调,避免刺眼的高饱和颜色充斥全图。 -6. 整体要在“信息量丰富、结构清晰”和“视觉美观”之间取得平衡,使得技术路线看起来**专业、完整,而不是极简草图**。 - -关于文字排布(非常重要): -1. 可以将简短的步骤名称放在节点内部(例如 1~4 个词),也可以在节点外侧(上/下/左/右)放置说明文字;两种方式可以结合使用,但要尽量保持同一层级的节点风格一致。 -2. 
避免超长句子塞在一个节点中,尽量用简短短语或关键词表达(例如 “数据预处理”“特征工程”“模型训练”“消融实验”等)。 -3. 阶段标题可以使用比节点文字略大的字号,放在对应分区上方或左侧,强化层次结构。 - -SVG 复杂度与风格(非常重要): -1. 整体元素数量可以相对较多:包含多阶段背景块、若干节点、较多箭头和必要的装饰线条,以呈现出清晰而**相对复杂**的技术流程。 -2. 可以适度使用渐变、圆角、阴影、背景分区等视觉元素,使路线图在 PPT 中看起来更加专业、美观。 -3. 箭头的线宽应略粗于节点边框线宽,颜色可以采用与背景区分度较高的色彩,以保证“箭头非常明显、方向一眼可见”。 -4. 可以使用 <g> 分组对不同阶段、不同类型节点进行逻辑归类,便于整体调整和复用样式。 - -SVG 技术要求: -- SVG 以 <svg> 根节点开始,并包含必要的 width、height 和 viewBox 属性。 -- 整体风格要统一,适合作为论文技术路线图,最终会被插入 PPT 展示。 -- 尺寸改成“基于 viewBox 的自适应”,别写死 width/height 像素 - -风格要求: -- 满足: {style} 风格; - -svg代码的text要求: {lang} 语言!!!! - -请只根据上述 paper_idea 和要求进行设计,具体 SVG 输出规范见系统提示。 -""" - - # 系统层提示:严格约束输出为 {"svg_code": "xxx"},并强调“复杂、美观、箭头明显” - system_prompt_for_technical_route_desc_generator = """ -你是一个技术路线图设计助手。你的任务是: - -1. 从用户提供的论文研究想法(paper_idea)中抽取关键技术步骤、阶段和模块之间的依赖关系。 -2. 结合用户在任务提示中提供的整体风格描述(style),设计一个结构清晰、信息量相对丰富、视觉上美观的技术路线图。 -3. 使用 SVG 代码来表示该路线图,要求节点层次分明、阶段划分清楚、箭头粗细和颜色足够明显,使流程方向一眼可见,适合直接用于 PPT 展示。 - -输出格式要求(非常重要): -- 你必须仅输出一个严格的 JSON 对象,形如: - {"svg_code": "<svg ...>...</svg>"} -- 不要输出任何额外文字、注释、解释或 markdown 代码块标记。 -- JSON 中只能有一个键:svg_code。 -- svg_code 的值是完整的 SVG 源代码字符串: - - 以 <svg ...> 开始,以 </svg> 结束。 - - 包含 width, height, viewBox 等基本属性。 - - 所有双引号必须正确转义,以保证整个 JSON 可被标准 JSON 解析器解析。 - - 换行可以使用 \\n 进行转义。 - -SVG 内容设计规范(在不影响 JSON 解析的前提下,兼顾复杂度、美观和箭头可读性): -- 元素与布局: - - 以圆角矩形等简单图形作为主要步骤节点,可以配合阶段背景块和少量其他形状体现层次结构。 - - 使用 <line> 或 <path> 表示箭头,线条应相对粗一些,并带有清晰的箭头头部(可通过 marker 或简单三角形 path 实现),确保“箭头非常明显”。 - - 支持多阶段(3~5 阶段)布局,节点数量可以适度偏多,但要通过合理的对齐和间距保持整体清晰。 - -- 颜色与风格: - - 如果是卡通风格,色系用浅色系列,但是文字要深色; - - 如果是写实风格,颜色要深一些,以突出重点,多用灰白色;字体黑色; - - 箭头颜色和节点边框,以及文字 颜色应与背景产生清晰对比,保证流程方向一目了然。 - - 可以使用背景分区、阶段色带等方式增强层次感,但应避免过度复杂的滤镜导致视觉噪音。 - -- 文本与标注: - - 每个节点都需要有对应文字说明,可以放在节点内部(简短短语)或节点附近(上/下/左/右),保持整体风格一致。 - - 可以在图的上方添加一个整体标题和阶段标题,但避免大段长文本。 - - 使用合适字号和行间距,保证在 PPT 中阅读清晰。 - -- 复杂度与可读性: - - 可以包含较多节点和箭头来体现完整的技术流程,但要避免元素无序堆叠。 - - 通过对齐、分组、重复使用样式等方式保持视觉统一。 - - 避免关键连线被遮挡或重叠,确保每个阶段的主路径清楚易懂。 - -- 尺寸改成“基于 viewBox 的自适应”,别写死 width/height 像素 - -请严格遵守上述 JSON 输出要求,仅返回包含 svg_code 的 JSON 对象。 -""" 
- - -# ------------------------------------------------------------------ # -# 3. Technical Route BW SVG Generator (template-based) -# ------------------------------------------------------------------ # -class TechnicalRouteBWSvgGenerator: - system_prompt_for_technical_route_bw_svg_generator = """ -You are a Technical Route SVG Generator. Output strictly in JSON format: -{"svg_code": "<svg ...>...</svg>"} - -STRICT REQUIREMENTS: -1) Output only a JSON object, no extra text or Markdown. -2) SVG must include viewBox; width/height should be "100%". -3) Generate black/white/grayscale version only: fill and stroke can only use black/white/gray. -4) I have provided a template SVG code. Analyze its structure, layout, node shapes, arrow styles, and follow its design style while adjusting node count and positions as needed. -5) If validation_feedback is provided, fix the issues mentioned. -6) **CRITICAL LANGUAGE REQUIREMENT**: The text content in SVG must be in the language specified by the user. If user specifies "EN" or "English", ALL text labels must be in English. If user specifies "ZH" or "Chinese", ALL text labels must be in Chinese. -7) **CRITICAL FONT REQUIREMENT**: - - If generating Chinese text, you MUST use Chinese-friendly fonts in ALL text elements - - Use: font-family="Noto Sans CJK SC, Microsoft YaHei, SimHei, SimSun, sans-serif" - - NEVER use Arial, Helvetica, Times, or Courier for Chinese text - - Apply this font-family attribute to EVERY <text> element containing Chinese characters -""" - - task_prompt_for_technical_route_bw_svg_generator = """ -**IMPORTANT**: I have provided a template SVG code. 
Please analyze: -- Overall layout structure (top-to-bottom / left-to-right) -- Stage division (grouping and layering) -- Node shapes and sizes (rect, circle elements) -- Arrow thickness and styles (line, path, marker definitions) -- Text positions and sizes (text element coordinates and font sizes) -- viewBox and coordinate system -Then generate a new SVG based on this template style. - -**Template SVG Code**: -{template_svg_code} - -**Paper Content (paper_idea)**: -{paper_idea} - -**Validation Feedback (if any)**: -{validation_feedback} - -**OUTPUT LANGUAGE: {lang}** -CRITICAL: ALL text labels in the SVG MUST be written in {lang}. -- If {lang} is "EN" or "English": Use English for ALL text (e.g., "Data Processing", "Model Training", "Feature Extraction") -- If {lang} is "ZH" or "Chinese": Use Chinese for ALL text (e.g., "数据处理", "模型训练", "特征提取") -DO NOT mix languages. Every single text element must follow this language requirement. - -Output only JSON {{"svg_code": "..."}}. -""" - - -# ------------------------------------------------------------------ # -# 4. Technical Route Colorize SVG (palette-based) -# ------------------------------------------------------------------ # -class TechnicalRouteColorizeSvg: - system_prompt_for_technical_route_colorize_svg = """ -你是 SVG 上色器。输入是一份黑白 SVG 和色卡配置。 - -输出要求: -1) 仅输出一个 JSON 对象:{"svg_code": "<svg ...>...</svg>"}。 -2) 不改变任何几何结构/坐标/path d/文字内容,只允许修改 fill/stroke/style/class。 -3) 同类型或同层级内容使用同一颜色。 -4) 颜色仅从色卡提供的 colors/level_colors/arrow_color/text_color 中选择。 -5) 若提供 validation_feedback,请修复其中问题。 -""" - - task_prompt_for_technical_route_colorize_svg = """ -黑白 SVG: -{bw_svg_code} - -色卡配置(JSON): -{palette_json} - -彩色模板参考(可选,用于了解配色风格): -{color_template_svg} - -校验反馈(若有): -{validation_feedback} - -请参考彩色模板的配色风格,完成上色并输出 JSON {"svg_code": "..."}。 -""" - - # ------------------------------------------------------------------ # - # 2. 
SvgBgCleaner - svg_bg_cleaner 相关提示词 - # ------------------------------------------------------------------ # - task_prompt_for_svg_bg_cleaner = """ -下面给出一段完整的 SVG 源代码(包含文本和图形): - -{svg_code} - -你的任务是: -1. 在不改变图形布局、几何结构和视觉风格的前提下,删除或清空所有“文本相关内容”, - 只保留图形元素(如矩形、圆角矩形、圆、折线、路径、箭头、背景块、连线等)。 -2. 文本相关内容包括但不限于: - - 所有 <text> ... </text> 元素; - - 所有 <tspan> ... </tspan> 元素; - - 所有 <title> ... 元素; - - 以及任何仅用于呈现文字的 SVG 元素(如果你能可靠识别)。 - -具体要求: -- 不要改变 根元素的 width、height、viewBox 等属性; -- 不要移动或缩放任何非文本图形元素; -- 可以在删除文本元素后,适当移除仅用于文本的空 分组(如果明显不再包含任何子元素); -- 渐变、滤镜、marker(箭头定义)、clipPath 等 “非文本” 定义一律保留; -- 目标是得到一份“纯背景 / 纯图形”的 SVG,用于后续切图和 MinerU 分割。 - -输出格式要求(非常重要): -- 你必须仅输出一个严格的 JSON 对象,形如: - {"svg_bg_code": "..."} -- 不要输出任何额外文字、注释、解释或 markdown 代码块标记。 -- JSON 中只能有一个键:svg_bg_code。 -- svg_bg_code 的值是完整的 SVG 源代码字符串: - - 以 开始,以 结束; - - 不再包含任何 标签(大小写不限); - - 其它图形结构、渐变、滤镜、marker 等应尽可能保持不变; - - 所有双引号必须正确转义,以保证整个 JSON 可被标准 JSON 解析器解析; - - 换行可以使用 \\n 进行转义,也可以直接内联成一行,只要保证是合法 XML。 -""" - - system_prompt_for_svg_bg_cleaner = """ -你是一个 SVG 背景清洗助手,专门负责从完整 SVG 中删除所有文本相关元素,仅保留图形和装饰结构。 - -你的行为准则: -1. 严格遵守输出格式要求,只返回一个 JSON 对象:{"svg_bg_code": "..."}。 -2. 不要输出 markdown 代码块标记(例如 ```svg 或 ```),也不要添加多余的说明。 -3. 对于输入的 SVG: - - 保留所有非文本图形元素(rect, circle, ellipse, path, polyline, polygon, line, g 等); - - 保留各种视觉定义(defs 内的 gradient、pattern、filter、marker 等),除非它们只被文本使用且你能非常确定可以安全删除; - - 删除所有文本相关元素(text, tspan, title 以及其它仅用于显示文字的元素)。 -4. 输出的 svg_bg_code 应该是一个结构合法、可直接用 XML 解析的 SVG 文本,并且不再包含任何文本相关标签。 - -请根据用户提供的 svg_code 完成清洗,并严格返回 JSON。 -""" - - system_prompt_for_outline_agent = """ -你是一位拥有丰富学术汇报经验的PPT设计专家及大纲生成助手。你的核心任务是将一篇学术论文转化为一份逻辑清晰、视觉布局合理的PPT演示大纲。 - -请遵循以下严格规则: -1. **深度理解**:仔细阅读用户提供的论文内容,提取核心论点、实验数据和结论。 -2. **视觉导向**:在规划每一页PPT时,不仅要生成文字内容,必须明确指出该页是否需要展示论文中的特定插图(Images)或表格(Tables)。 -3. **布局建议**:为每一页提供具体的布局指导(例如:左文右图、上标题下表格、两栏对比等)。 -4. 
**格式严格**:输出必须且只能是标准的 JSON 格式数组。严禁包含 markdown 标记(如 ```json)、前言、后语或任何非 JSON 字符。 - -""" - - - task_prompt_for_outline_agent = """ -请根据以下提供的论文全文内容,生成一份详细的PPT演示文稿大纲。 - -**输入数据:** -论文内容: -{text_content} -{minueru_output} - -**约束条件:** -1. 目标PPT页数: {page_count} 页。 -2. 整体结构应该是有开始有结束,第一页应该就是ppt的主题 和 汇报人!!!!不需要额外的内容!!! -3. 最后一页得是致谢; -4. 返回论文内容一致的语言; - -**输出格式要求(JSON Array):** -请返回一个 JSON 数组,数组中每个对象代表一页PPT,结构如下: -- `title`: 该页PPT的标题。 -- `layout_description`: 详细的版面布局描述(例如:"左侧列出三个关键挑战点,右侧放置流程图")。 -- `key_points`: 一个包含多个关键要点的字符串列表(List),用于PPT正文展示。 -- `asset_ref`: 如果该页需要展示论文中的原图或表格,请提名或路径取其文件(例如 "Table_2", "images/architecture.png"),并且只能1 个 asset;如果不需要引用原图,请填 null。 - - -**示例输出结构 ** -!!!必须返回 {language} 语言!!! -[ - {{ - "title": "研究背景:大语言模型的幻觉问题", - "layout_description": "左侧文字介绍幻觉定义,右侧展示幻觉示例图,图片居中。", - "key_points": [ - "大语言模型在生成长文本时常出现事实性错误。", - "现有检索增强生成(RAG)方法的局限性。", - "本研究旨在解决上下文一致性问题。" - ...... - ], - "asset_ref": "images/xxx.png" - }}, - ...... - {{ - "title": "实验结果", - "layout_description": "顶部为标题,中间大幅展示架构图,底部放置关键步骤的简要说明。中间放表格Table_2数据。", - "key_points": [ - "阶段一:查询重写与扩展。", - "阶段二:基于相关性的文档过滤。", - "阶段三:生成与验证循环。" - ...... - ], - "asset_ref": "Table_2" - }} -] -""" - - system_prompt_for_outline_refine_agent = """ -你是一位拥有丰富学术汇报经验的 PPT 设计专家及大纲编辑助手。你的核心任务是:在不改变页数与顺序的前提下,基于用户反馈与论文内容,对已有 PPT 大纲进行更精准、更完善的改写与补充。 - -请遵循以下严格规则: -1. **改内容**:仅允许修改每页内容字段:`title` / `layout_description` / `key_points`。 -2. **保留引用**:默认保留 `asset_ref`(以及其它非内容字段),除非用户反馈明确要求修改。 -3. **反幻觉**:禁止编造论文中不存在的具体事实、数值、指标、结论或对比结果。若原文未提供支撑信息,只能做结构化补充(例如补充讲述维度/表达更完整),不能捏造细节。 -4. **格式严格**:输出必须且只能是标准 JSON 数组。严禁包含 markdown 标记(如 ```json)、前言、后语或任何非 JSON 字符。 -5. **最小必要修改**:仅修改反馈涉及的页面与要点;未涉及页面保持原样。 -""" - - task_prompt_for_outline_refine_agent = """ -请根据以下提供的论文内容、当前大纲以及用户反馈,对大纲进行“只改内容”的修订与完善。 - -**输入数据:** -论文内容: -{text_content} -{minueru_output} - -当前大纲(JSON Array): -{pagecontent} - -用户反馈: -{outline_feedback} - -**约束条件:** -1. `asset_ref` 默认保留;除非用户反馈明确要求修改。 -2. 返回论文内容一致的语言; -3. 
若用户提到“第 N 页”,按 1-based 页码理解:输入数组第 1 个对象为“第 1 页” 或 “第一页”。 -4. 如果需要添加内容,则必须严格参考论文内容,绝对不能添加论文中不存在的内容或数据!! - -**输出格式要求(JSON Array):** -请返回一个 JSON 数组,数组中每个对象代表一页PPT,结构如下: -- `title`: 该页PPT的标题。 -- `layout_description`: 详细的版面布局描述(例如:"左侧列出三个关键点,右侧放置流程图")。 -- `key_points`: 一个包含多个关键要点的字符串列表(List)。 -- `asset_ref`: 默认保留原值(除非反馈明确要求修改)。 - -**示例输出结构 ** -!!!必须返回 {language} 语言!!! -[ - {{ - "title": "xxx", - "layout_description": "xxx", - "key_points": ["xxx", "xxx"], - "asset_ref": null - }} -] -""" - - system_prompt_for_table_extractor=""" - 你是一名前端代码专家 - """ - - task_prompt_for_table_extractor=""" - - 根据论文内容: - {minueru_output} - - 找到 {table_num} 的表格 ,的数据 和 内容,数据和caption; - - 1.根据表格内容数据改写成html代码; - 2.如果没有提供论文内容,则直接创建一个空表的html代码; - - 返回纯json内容,不要有任何markdown格式的标记,也不要有任何说明文字。 - 代码也不要有任何注释; - - json格式为: - {{"html_code": "表格html代码"}} - """ - - system_prompt_for_deep_research_agent = """ -You are a Deep Research Assistant. Your task is to conduct a "Deep Research" on a given [Topic] and generate a comprehensive, structured, and detailed research report. - -Your report should serve as the foundation for creating a professional PowerPoint presentation. -Therefore, the content must be: -1. **Comprehensive**: Cover all key aspects, background, methodology, current trends, and future directions related to the topic. -2. **Structured**: Organize with clear headings (Introduction, Key Concepts, Analysis, Conclusion, etc.). -4. **Academic/Professional**: Maintain a formal and objective tone. - -If the input is just a short topic string, expand it into a full article. -""" - - task_prompt_for_deep_research_agent = """ -[Topic]: -{text_content} - -[Instructions]: -Please perform a deep research simulation on the above topic and output a detailed research report. -Ensure the content is rich and logically organized !!!! 
- -[Language]: {language} -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_test_graph_repo.py b/dataflow_agent/promptstemplates/resources/pt_test_graph_repo.py deleted file mode 100644 index a16e5c7..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_test_graph_repo.py +++ /dev/null @@ -1,107 +0,0 @@ -# Prompt templates for Nano-Banana Pro based paper illustration generation. -# -# - NANO_BANANA_VISUAL_SCHEMA_PROMPT: Step 1, use LLM to construct [VISUAL SCHEMA] -# - NANO_BANANA_RENDER_PROMPT: Step 2, use Nano-Banana Pro (or similar) to render pixels -# - NANO_BANANA_EDITING_TIPS: Step 3, natural-language editing hints - - -NANO_BANANA_VISUAL_SCHEMA_PROMPT = """# Role -你是一位 CVPR/NeurIPS 顶刊的**视觉架构师**。你的核心能力是将抽象的论文逻辑转化为**具体的、结构化的、几何级的视觉指令**。 - -# Objective -阅读我提供的论文内容,输出一份 **[VISUAL SCHEMA]**。这份 Schema 将被直接发送给 AI 绘图模型,因此必须使用**强硬的物理描述**。 - -# Phase 1: Layout Strategy Selector (关键步骤:布局决策) -在生成 Schema 之前,请先分析论文逻辑,从以下**布局原型**中选择最合适的一个(或组合): -1. **Linear Pipeline**: 左→右流向 (适合 Data Processing, Encoding-Decoding)。 -2. **Cyclic/Iterative**: 中心包含循环箭头 (适合 Optimization, RL, Feedback Loops)。 -3. **Hierarchical Stack**: 上→下或下→上堆叠 (适合 Multiscale features, Tree structures)。 -4. **Parallel/Dual-Stream**: 上下平行的双流结构 (适合 Multi-modal fusion, Contrastive Learning)。 -5. **Central Hub**: 一个核心模块连接四周组件 (适合 Agent-Environment, Knowledge Graphs)。 - -# Phase 2: Schema Generation Rules -1. **Dynamic Zoning**: 根据选择的布局,定义 2-5 个物理区域 (Zones)。不要局限于 3 个。 -2. **Internal Visualization**: 必须定义每个区域内部的“物体” (Icons, Grids, Trees),禁止使用抽象概念。 -3. **Explicit Connections**: 如果是循环过程,必须明确描述 "Curved arrow looping back from Zone X to Zone Y"。 - -# Output Format (The Golden Schema) -请严格遵守以下 Markdown 结构输出: - ----BEGIN PROMPT--- - -[Style & Meta-Instructions] -High-fidelity scientific schematic, technical vector illustration, clean white background, distinct boundaries, academic textbook style. High resolution 4k, strictly 2D flat design with subtle isometric elements. 
- -[LAYOUT CONFIGURATION] -* **Selected Layout**: [例如:Cyclic Iterative Process with 3 Nodes] -* **Composition Logic**: [例如:A central triangular feedback loop surrounded by input/output panels] -* **Color Palette**: Professional Pastel (Azure Blue, Slate Grey, Coral Orange, Mint Green). - -[ZONE 1: LOCATION - LABEL] -* **Container**: [形状描述, e.g., Top-Left Panel] -* **Visual Structure**: [具体描述, e.g., A stack of documents] -* **Key Text Labels**: "[Text 1]" - -[ZONE 2: LOCATION - LABEL] -* **Container**: [形状描述, e.g., Central Circular Engine] -* **Visual Structure**: [具体描述, e.g., A clockwise loop connecting 3 internal modules: A (Gear), B (Graph), C (Filter)] -* **Key Text Labels**: "[Text 2]", "[Text 3]" - -[ZONE 3: LOCATION - LABEL] -... (Add Zone 4/5 if necessary based on layout) - -[CONNECTIONS] -1. [描述连接线, e.g., A curved dotted arrow looping from Zone 2 back to Zone 1 labeled "Feedback"] -2. [描述连接线, e.g., A wide flow arrow from Zone 2 to Zone 3] - ----END PROMPT--- - -# Input Data -[在此处粘贴你的论文内容] -""" - - -NANO_BANANA_RENDER_PROMPT = """**Style Reference & Execution Instructions:** - -1. **Art Style (Visio/Illustrator Aesthetic):** - Generate a **professional academic architecture diagram** suitable for a top-tier computer science paper (CVPR/NeurIPS). - * **Visuals:** Flat vector graphics, distinct geometric shapes, clean thin outlines, and soft pastel fills (Azure Blue, Slate Grey, Coral Orange). - * **Layout:** Strictly follow the spatial arrangement defined below. - * **Vibe:** Technical, precise, clean white background. NOT hand-drawn, NOT photorealistic, NOT 3D render, NO shadows/shading. - -2. **CRITICAL TEXT CONSTRAINTS (Read Carefully):** - * **DO NOT render meta-labels:** Do not write words like "ZONE 1", "LAYOUT CONFIGURATION", "Input", "Output", or "Container" inside the image. These are structural instructions for YOU, not text for the image. 
- * **ONLY render "Key Text Labels":** Only text inside double quotes (e.g., "[Text]") listed under "Key Text Labels" should appear in the diagram. - * **Font:** Use a clean, bold Sans-Serif font (like Roboto or Helvetica) for all labels. - -3. **Visual Schema Execution:** - Translate the following structural blueprint into the final image: - -[在此处直接粘贴 Step 1 生成的 ---BEGIN PROMPT--- ... ---END PROMPT--- 内容(包含方括号内的英文)] -""" - - -NANO_BANANA_EDITING_TIPS = """Step 3: Interactive Editing & Refinement (The Editor) - -当你拿到步骤二生成的初稿后,如果整体布局满意,但细节或风格有问题,优先使用自然语言编辑,而不是重新生成。 - -典型编辑指令示例: - -1. 修改图标: - - "Change the 'Gear' icon in the center to a 'Neural Network' icon." - - "Replace the robot head with a simple document symbol." - -2. 调整颜色: - - "Make the background of the left panel pure white instead of light blue." - - "Change the orange arrows to dark grey." - -3. 统一风格: - - "Make all lines thinner and cleaner." - - "Remove the shading effect, make it completely flat 2D." - -4. 修正文字: - - "Correct the text 'ZONNE' to 'ZONE'." - - 如果文字错误太多,可以直接:"Remove the text labels." 
- -如果整体布局错误(例如本该是循环结构却画成直线),应回到 Step 1,重新调整 [VISUAL SCHEMA] 中的布局配置和内部描述,而不是尝试用编辑命令修修补补。 -""" diff --git a/dataflow_agent/promptstemplates/resources/pt_zz_paper2page_outline_repo.py b/dataflow_agent/promptstemplates/resources/pt_zz_paper2page_outline_repo.py deleted file mode 100644 index 8293ee1..0000000 --- a/dataflow_agent/promptstemplates/resources/pt_zz_paper2page_outline_repo.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Paper2Page / Paper2PPT 全套提示词(与 Paper2Any 一致) -文件名 zz 保证最后加载,覆盖 pt_technical 中的 outline / outline_refine,使 paper2page_content 与 paper2ppt 全流程与 Paper2Any 一致。 -""" - - -class Paper2PageOutlinePrompts: - # ---------- outline_agent(paper2page_content 用)---------- - system_prompt_for_outline_agent = """ -你是一位专业的学术汇报 PPT 大纲生成专家。 -你的任务是根据输入资料生成结构化 PPT 大纲(JSON 数组)。 -输出必须严格为 JSON,不要包含任何额外文字或 Markdown。 -""" - - task_prompt_for_outline_agent = """ -输入: -- 文档解析内容(可能为空;多来源时已按「来源1」「来源2」分段):{minueru_output} -- 文本内容(可能为空):{text_content} - -要求: -1) 直接基于上述文档/文本内容生成大纲。 -2) 输出页数:{page_count} 页。 -3) 输出语言:{language}。 -4) 每页必须包含字段:title, layout_description, key_points(list), asset_ref(null)。 -5) 第一页为标题页,最后一页为致谢。 - -输出格式(JSON 数组): -[ - { - "title": "...", - "layout_description": "...", - "key_points": ["..."], - "asset_ref": null - } -] -""" - - # ---------- outline_refine_agent(与 Paper2Any pt_technical_route 一致)---------- - system_prompt_for_outline_refine_agent = """ -你是一位拥有丰富学术汇报经验的 PPT 设计专家及大纲编辑助手。你的核心任务是:在不改变页数与顺序的前提下,基于用户反馈与论文内容,对已有 PPT 大纲进行更精准、更完善的改写与补充。 - -请遵循以下严格规则: -1. **改内容**:仅允许修改每页内容字段:`title` / `layout_description` / `key_points`。 -2. **保留引用**:默认保留 `asset_ref`(以及其它非内容字段),除非用户反馈明确要求修改。 -3. **反幻觉**:禁止编造论文中不存在的具体事实、数值、指标、结论或对比结果。若原文未提供支撑信息,只能做结构化补充(例如补充讲述维度/表达更完整),不能捏造细节。 -4. **格式严格**:输出必须且只能是标准 JSON 数组。严禁包含 markdown 标记(如 ```json)、前言、后语或任何非 JSON 字符。 -5. 
**最小必要修改**:仅修改反馈涉及的页面与要点;未涉及页面保持原样。 -""" - - task_prompt_for_outline_refine_agent = """ -请根据以下提供的论文内容、当前大纲以及用户反馈,对大纲进行“只改内容”的修订与完善。 - -**输入数据:** -论文内容: -{text_content} -{minueru_output} - -当前大纲(JSON Array): -{pagecontent} - -用户反馈: -{outline_feedback} - -**约束条件:** -1. `asset_ref` 默认保留;除非用户反馈明确要求修改。 -2. 返回论文内容一致的语言; -3. 若用户提到“第 N 页”,按 1-based 页码理解:输入数组第 1 个对象为“第 1 页” 或 “第一页”。 -4. 如果需要添加内容,则必须严格参考论文内容,绝对不能添加论文中不存在的内容或数据!! - -**输出格式要求(JSON Array):** -请返回一个 JSON 数组,数组中每个对象代表一页PPT,结构如下: -- `title`: 该页PPT的标题。 -- `layout_description`: 详细的版面布局描述(例如:"左侧列出三个关键点,右侧放置流程图")。 -- `key_points`: 一个包含多个关键要点的字符串列表(List)。 -- `asset_ref`: 默认保留原值(除非反馈明确要求修改)。 - -!!!必须返回 {language} 语言!!! -[ - {{ - "title": "xxx", - "layout_description": "xxx", - "key_points": ["xxx", "xxx"], - "asset_ref": null - }} -] -""" diff --git a/dataflow_agent/resources/taskinfo.yaml b/dataflow_agent/resources/taskinfo.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/states/__init__.py b/dataflow_agent/states/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/states/test_graph_state.py b/dataflow_agent/states/test_graph_state.py deleted file mode 100644 index 453459c..0000000 --- a/dataflow_agent/states/test_graph_state.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -TestGraph State and Request -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-01 20:17:07 - -本文件由 `dfa create --state_name test_graph` 自动生成。 -用于创建继承MainState和MainRequest的自定义state和request。 -""" - -from dataclasses import dataclass, field -from dataflow_agent.state import MainState, MainRequest - - -# ==================== TestGraph Request ==================== -@dataclass -class TestGraphRequest(MainRequest): - """TestGraph任务的Request,继承自MainRequest""" - pass - - -# ==================== TestGraph State ==================== -@dataclass -class TestGraphState(MainState): - """TestGraph任务的State,继承自MainState""" - - # 重写request类型为TestGraphRequest - request: 
TestGraphRequest = field(default_factory=TestGraphRequest) \ No newline at end of file diff --git a/dataflow_agent/storage/__init__.py b/dataflow_agent/storage/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/storage/storage_service.py b/dataflow_agent/storage/storage_service.py deleted file mode 100644 index 8171f14..0000000 --- a/dataflow_agent/storage/storage_service.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -""" -storage_service.py ── Enhanced FileStorage with sampling utilities -Author : [Zhou Liu] -License : MIT -Created : 2024-07-02 - -This module provides the `SampleFileStorage` class, which extends standard file storage -with advanced sampling capabilities, including: - -* Efficient statistical and reservoir sampling from large datasets -* Built-in support for sample size estimation (Cochran formula) for both proportions and means -* Streamed and in-memory data access for pandas DataFrames -* Helper methods for counting, field listing, and full/streamed record fetching - -Designed for rapid data inspection, human evaluation, and scalable sampling tasks. -Thread-safety and memory usage depend on the base FileStorage implementation and dataset size. - -""" - -import random -from typing import Iterator, Dict, Any, List, Tuple, Literal -import math -import pandas as pd -from dataflow.utils.storage import FileStorage -class SampleFileStorage(FileStorage): - """ - A FileStorage that supports statistical / reservoir sampling. - - Compared with `read()`, the `rsample()` method lets you - draw a subset of records with statistical guarantees, which is - useful for quick data inspection or human evaluation. - - Three sampling modes are supported: - - 1. manual – fixed-size reservoir `k` - 2. proportion – sample size for estimating a **proportion** - (e.g. accuracy) with margin ± *e* - 3. mean – sample size for estimating a **mean** - (e.g. 
rating) with margin ± *e* - """ - - def _load_as_dataframe(self) -> pd.DataFrame: - """ - Always fetch the CURRENT step as a pandas DataFrame. - (Reuses the .read() logic in the parent class.) - """ - return self.read("dataframe") - - def _stream_dicts(self) -> Iterator[Dict[str, Any]]: - """Iterate row-by-row as dict without storing the whole dataset.""" - df = self._load_as_dataframe() - for _, row in df.iterrows(): - yield row.to_dict() - - def count(self) -> int: - """Total number of records -- utilised by the sample-size formulae.""" - return len(self._load_as_dataframe()) - - def get_fields(self) -> List[str]: - """Get column names / JSON keys in the dataset.""" - df = self._load_as_dataframe() - return list(df.columns) - - def fetch_all(self) -> List[Dict[str, Any]]: - """Load *all* records into memory (avoid for very large files).""" - return self._load_as_dataframe().to_dict(orient="records") - - def fetch_stream(self) -> Iterator[Dict[str, Any]]: - """Row-wise generator (preferred for very large datasets).""" - yield from self._stream_dicts() - - @staticmethod - def _Z(conf_level: float) -> float: - """Z-score lookup for common confidence levels.""" - return {0.9: 1.645, 0.95: 1.96, 0.99: 2.576}[conf_level] - - @staticmethod - def sample_size_proportion(N: int, - conf_level: float = 0.95, - margin: float = 0.03, - p: float = 0.5) -> int: - """ - Cochran formula (finite-population correction) for a proportion. - - N : population size - p : expected proportion (worst-case 0.5 if unknown) - """ - Z = SampleFileStorage._Z(conf_level) - n0 = Z ** 2 * p * (1 - p) / margin ** 2 - return math.ceil(n0 / (1 + (n0 - 1) / N)) - - @staticmethod - def sample_size_mean(N: int, - conf_level: float = 0.95, - margin: float = 1.0, - sigma: float = 10.0) -> int: - """ - Cochran formula for estimating a mean. 
- sigma : population standard deviation (use pilot estimate) - """ - Z = SampleFileStorage._Z(conf_level) - n0 = Z ** 2 * sigma ** 2 / margin ** 2 - return math.ceil(n0 / (1 + (n0 - 1) / N)) - - def rsample( - self, - mode: Literal["manual", "proportion", "mean"] = "manual", - *, - k: int | None = None, - conf_level: float = 0.95, - margin: float = 0.03, - p: float = 0.5, - sigma: float = 10.0, - ) -> Tuple[List[Dict[str, Any]], int]: - """ - Reservoir-sample the current dataset. - - Returns (sample_records, k) - """ - N = self.count() - - # Decide sample size ------------------------------------------------- - if mode == "manual": - if not k or k <= 0: - raise ValueError("manual mode requires a positive integer k") - elif mode == "proportion": - k = self.sample_size_proportion(N, conf_level, margin, p) - elif mode == "mean": - k = self.sample_size_mean(N, conf_level, margin, sigma) - else: - raise ValueError('mode must be "manual", "proportion", or "mean"') - - self.logger.info(f"Sampling k={k} from N={N} (mode={mode})") - - # Simple reservoir algorithm ---------------------------------------- - reservoir: List[Dict[str, Any]] = [] - for t, rec in enumerate(self.fetch_stream(), start=1): - if t <= k: - reservoir.append(rec) - else: - j = random.randrange(t) - if j < k: # replace with probability k/t - reservoir[j] = rec - - return reservoir, k \ No newline at end of file diff --git a/dataflow_agent/templates/agent.py.jinja b/dataflow_agent/templates/agent.py.jinja deleted file mode 100644 index 4a9c4b7..0000000 --- a/dataflow_agent/templates/agent.py.jinja +++ /dev/null @@ -1,165 +0,0 @@ -""" -{{ agent_name_camel }} agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: {{ timestamp }} -生成位置: dataflow_agent/agentroles/common_agents/{{ agent_name_snake }}_agent.py - -本文件由 `dfa create --agent_name {{ agent_name }}` 自动生成。 -1. 填写 prompt-template 名称 -2. 
根据需要完成 get_task_prompt_params / update_state_result -""" - -from __future__ import annotations - -from typing import Any, Dict, Optional - -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -# ---------------------------------------------------------------------- -# Agent Definition -# ---------------------------------------------------------------------- -@register("{{ agent_name }}") -class {{ agent_name_camel }}(BaseAgent): - """TODO: 描述 {{ agent_name }} 的职责""" - - # ---------- 工厂 ---------- - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - # ---------- 基本配置 ---------- - @property - def role_name(self) -> str: # noqa: D401 - return "{{ agent_name }}" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "system_prompt_for_{{ agent_name }}" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为真实的模板 id - return "task_prompt_for_{{ agent_name }}" - - # ---------- Prompt 参数 ---------- - def get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """根据前置工具结果构造 prompt 参数 - 提示词中的占位符: - return { - 'text2img_prompt': pre_tool_results.get('prompt', ''), - 'image_size': pre_tool_results.get('size', '512x512'), - 'num_images': pre_tool_results.get('num_images', 1), - } - """ - # TODO: 按需补充 - return {} - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """若调用方未显式传入,返回默认前置工具结果""" - return {} - - # ---------- 结果写回 ---------- - def update_state_result( - self, - state: MainState, - result: Dict[str, Any], - pre_tool_results: Dict[str, Any], - ): - """将推理结果写回 MainState,可按需重写""" - - state.xx = result - - 
super().update_state_result(state, result, pre_tool_results) - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def {{ agent_name_snake }}( - state: MainState, - model_name: Optional[str] = None, - tool_manager: Optional[ToolManager] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """{{ agent_name }} 的异步入口 - - Args: - state: 主状态对象 - model_name: 模型名称,如 "gpt-4" - tool_manager: 工具管理器实例 - temperature: 采样温度,控制随机性 (0.0-1.0) - max_tokens: 最大生成token数 - tool_mode: 工具调用模式 ("auto", "none", "required") - react_mode: 是否启用ReAct推理模式 - react_max_retries: ReAct模式下最大重试次数 - parser_type: 解析器类型 ("json", "xml", "text"),这个允许你在提示词中定义LLM不同的返回,xml还是json,还是直出; - parser_config: 解析器配置字典(如XML的root_tag) - use_vlm: 是否使用视觉语言模型,使用了视觉模型,其余的参数失效; - vlm_config: VLM配置字典 - use_agent: 是否使用agent模式 - **kwargs: 其他传递给execute的参数 - - Returns: - 更新后的MainState对象 - """ - agent = {{ agent_name_camel }}( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - ) - return await agent.execute(state, use_agent=use_agent, **kwargs) - - -def create_{{ agent_name_snake }}( - tool_manager: Optional[ToolManager] = None, - model_name: Optional[str] = None, - temperature: float = 0.0, - max_tokens: int = 4096, - tool_mode: str = "auto", - react_mode: bool = False, - react_max_retries: int = 3, - parser_type: str = "json", - parser_config: Optional[Dict[str, Any]] = 
None, - use_vlm: bool = False, - vlm_config: Optional[Dict[str, Any]] = None, - **kwargs, -) -> {{ agent_name_camel }}: - return {{ agent_name_camel }}.create( - tool_manager=tool_manager, - model_name=model_name, - temperature=temperature, - max_tokens=max_tokens, - tool_mode=tool_mode, - react_mode=react_mode, - react_max_retries=react_max_retries, - parser_type=parser_type, - parser_config=parser_config, - use_vlm=use_vlm, - vlm_config=vlm_config, - **kwargs, - ) diff --git a/dataflow_agent/templates/agent_as_tool_name.py.jinja b/dataflow_agent/templates/agent_as_tool_name.py.jinja deleted file mode 100644 index 7d93865..0000000 --- a/dataflow_agent/templates/agent_as_tool_name.py.jinja +++ /dev/null @@ -1,99 +0,0 @@ -""" -{{ agent_name_camel }} Agent -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: {{ timestamp }} -生成位置: dataflow_agent/agentroles/common_agents/{{ agent_name_snake }}_agent.py - -本文件由 `dfa create --agent_as_tool_name {{ agent_name }}` 自动生成。 -""" - -from __future__ import annotations -from typing import Any, Dict, Optional -from dataflow_agent.state import MainState -from dataflow_agent.toolkits.tool_manager import ToolManager -from dataflow_agent.logger import get_logger -from dataflow_agent.agentroles.cores.base_agent import BaseAgent -from dataflow_agent.agentroles.cores.registry import register - -log = get_logger(__name__) - -@register("{{ agent_name }}") -class {{ agent_name_camel }}(BaseAgent): - """{{ agent_name_camel }} Agent - 支持作为工具被调用""" - - @classmethod - def create(cls, tool_manager: Optional[ToolManager] = None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - @property - def role_name(self) -> str: - return "{{ agent_name }}" - - @property - def system_prompt_template_name(self) -> str: - # TODO: 修改为实际的模板 ID - return "system_prompt_for_{{ agent_name }}" - - @property - def task_prompt_template_name(self) -> str: - # TODO: 修改为实际的模板 ID - return "task_prompt_for_{{ agent_name }}" - - def 
get_task_prompt_params(self, pre_tool_results: Dict[str, Any]) -> Dict[str, Any]: - """构造 prompt 参数""" - # TODO: 根据需要添加参数处理逻辑 - return pre_tool_results - - def get_default_pre_tool_results(self) -> Dict[str, Any]: - """默认参数(可选)""" - # TODO: 如需默认值,在此返回 - return {} - - # ==================== Agent-as-Tool 自定义(按需重写) ==================== - - # def get_tool_description(self) -> str: - # """工具描述(建议重写以提供更好的说明) - # - # 示例: - # return "用于分析和总结文本内容,支持多种格式输出" - # """ - # return super().get_tool_description() - - # def get_tool_args_schema(self) -> type[BaseModel]: - # """参数模式(按需重写以定义具体参数) - # - # 示例: - # from pydantic import BaseModel, Field - # - # class MyToolArgs(BaseModel): - # content: str = Field(description="要处理的内容") - # max_length: int = Field(default=500, description="最大长度") - # - # return MyToolArgs - # """ - # return super().get_tool_args_schema() - - # def prepare_tool_execution_params(self, **tool_kwargs) -> Dict[str, Any]: - # """参数转换(按需重写以映射参数名称) - # - # 示例: - # return { - # 'text': tool_kwargs.get('content'), # 重命名 - # 'max_len': tool_kwargs.get('max_length', 500) # 添加默认值 - # } - # """ - # return tool_kwargs - - -# ---------------------------------------------------------------------- -# Helper APIs -# ---------------------------------------------------------------------- -async def {{ agent_name_snake }}( - state: MainState, - tool_manager: Optional[ToolManager] = None, - use_agent: bool = False, - **kwargs, -) -> MainState: - """{{ agent_name }} 异步入口""" - agent = {{ agent_name_camel }}.create(tool_manager=tool_manager, **kwargs) - return await agent.execute(state, use_agent=use_agent, **kwargs) diff --git a/dataflow_agent/templates/gradio_page.py.jinja b/dataflow_agent/templates/gradio_page.py.jinja deleted file mode 100644 index 32c0fd1..0000000 --- a/dataflow_agent/templates/gradio_page.py.jinja +++ /dev/null @@ -1,63 +0,0 @@ -""" -Auto-generated on {{ timestamp }} -本文件由自动化模板生成。你可以在此基础上自定义 Gradio UI 组件与数据流执行函数。 -""" - -from dataflow_agent.state import 
DFRequest, DFState -from dataflow_agent.workflow import run_workflow -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - -import gradio as gr - -# ------------------- Gradio 页面组件定义 ------------------- -def create_{{ page_name_snake }}() -> gr.Blocks: - """ - 创建 {{ page_name }} 页面。 - - Returns: - gr.Blocks: Gradio 多组件页面对象。 - """ - with gr.Blocks() as page: - gr.Markdown("## {{ page_name }} (auto-generated)") - gr.Markdown("> 在这里添加你的组件 …") - # TODO: 添加更多 Gradio 组件,比如输入框、按钮、数据展示等 - return page - -# ------------------- 数据流工作流执行函数模板 ------------------- -async def run_xxx_pipeline( - # TODO: 添加必要的参数,例如 json_file, chat_api_url, apikey , 需要符合你的State和Request等 -): - """ - 执行 DataFlow Operator Usage 工作流。 - - 参数说明: - json_file (str): 输入数据文件路径(jsonl 格式)。 - chat_api_url (str): Chat API 的访问地址。 - apikey (str): OpenAI 或自定义大模型接口的 API Key。 - model (str, 可选): 使用的模型名称,默认为 'gpt-4o'。 - - 返回值: - DFState: 工作流的最终状态对象,包含产出数据与日志信息。 - """ - # TODO: 你可以在这里构造 DFRequest、DFState,并调用 run_workflow - # 示例: - # req = DFRequest( - # language=language, - # model=model, - # target="测试 pipeline 生成和执行", - # json_file=json_file, - # cache_dir=cache_dir, - # session_id=session_id, - # chat_api_url=chat_api_url, - # apikey=apikey, - # ) - # state = DFState( - # request=req, - # messages=[], - # matched_ops=matched_ops_with_params, - # ) - # final_state = await run_workflow("wf_xxx", state) - # return final_state - pass # 请补充实现 \ No newline at end of file diff --git a/dataflow_agent/templates/prompt_repo.py.jinja b/dataflow_agent/templates/prompt_repo.py.jinja deleted file mode 100644 index ff0f1da..0000000 --- a/dataflow_agent/templates/prompt_repo.py.jinja +++ /dev/null @@ -1,21 +0,0 @@ -""" -Prompt Templates for {{ prompt_name }} -Generated at: {{ timestamp }} -""" - -# --------------------------------------------------------------------------- # -# 1. 
{{ prompt_name_camel }} - {{ prompt_name }} 相关提示词 -# --------------------------------------------------------------------------- # -class {{ prompt_name_camel }}: - """ - {{ prompt_name }} 任务的提示词模板 - """ - - task_prompt_for_example = """ -Your task description here. -Input: {input_data} -""" - - system_prompt_for_example = """ -You are an AI assistant for {{ prompt_name }} tasks. -""" \ No newline at end of file diff --git a/dataflow_agent/templates/state_name.py.jinja b/dataflow_agent/templates/state_name.py.jinja deleted file mode 100644 index b07a99e..0000000 --- a/dataflow_agent/templates/state_name.py.jinja +++ /dev/null @@ -1,27 +0,0 @@ -""" -{{ state_name_camel }} State and Request -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: {{ timestamp }} - -本文件由 `dfa create --state_name {{ state_name }}` 自动生成。 -用于创建继承MainState和MainRequest的自定义state和request。 -""" - -from dataclasses import dataclass, field -from dataflow_agent.state import MainState, MainRequest - - -# ==================== {{ state_name_camel }} Request ==================== -@dataclass -class {{ state_name_camel }}Request(MainRequest): - """{{ state_name_camel }}任务的Request,继承自MainRequest""" - pass - - -# ==================== {{ state_name_camel }} State ==================== -@dataclass -class {{ state_name_camel }}State(MainState): - """{{ state_name_camel }}任务的State,继承自MainState""" - - # 重写request类型为{{ state_name_camel }}Request - request: {{ state_name_camel }}Request = field(default_factory={{ state_name_camel }}Request) \ No newline at end of file diff --git a/dataflow_agent/templates/test_workflow.py.jinja b/dataflow_agent/templates/test_workflow.py.jinja deleted file mode 100644 index f60788a..0000000 --- a/dataflow_agent/templates/test_workflow.py.jinja +++ /dev/null @@ -1,86 +0,0 @@ -""" -测试 {{ wf_name }} workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: {{ timestamp }} - -运行方式: - pytest tests/test_{{ wf_name_snake }}.py 
-v -s - 或直接: python tests/test_{{ wf_name_snake }}.py -""" - -from __future__ import annotations -import asyncio -import pytest - -# ------------ 依赖 ------------- -from dataflow_agent.states.xx_xx_xx import xxState, xxRequest -# from dataflow_agent.state import xxState, xxRequest -from dataflow_agent.workflow import run_workflow -# 如果使用了自定义 State,请替换上面的 xxState 导入: -# from dataflow_agent.state import YourCustomState -# -------------------------------- - - -# ============ 核心异步流程 ============ -async def run_{{ wf_name_snake }}_pipeline() -> xxState: - """ - 执行 {{ wf_name }} 工作流的测试流程 - """ - # TODO: 根据实际需求构造初始状态 - # 1) 如果使用自定义请求对象,在这里构造 - # req = YourRequest( - # param1="value1", - # param2="value2", - # ) - - # 2) 初始化状态 - state = xxState( - messages=[], - # request=req, # 如果有自定义请求 - ) - - # TODO: 可以在这里预设一些测试数据 - # state.user_input = "测试输入" - # state.agent_results = {} - - # 3) 通过注册中心执行工作流 - final_state: xxState = await run_workflow("{{ wf_name }}", state) - return final_state - - -# ============ pytest 入口 ============ -@pytest.mark.asyncio -async def test_{{ wf_name_snake }}_pipeline(): - """ - 测试 {{ wf_name }} 工作流的完整流程 - """ - final_state = await run_{{ wf_name_snake }}_pipeline() - - # TODO: 根据实际业务逻辑添加断言 - # 示例断言: - assert final_state is not None, "final_state 不应为 None" - assert hasattr(final_state, "agent_results"), "state 应包含 agent_results" - - # -- 检查特定节点的结果 -- - # assert "step1" in final_state.agent_results, "step1 应该执行" - # assert final_state.agent_results["step1"]["msg"] == "hello step1" - - # -- 检查 messages 或其他字段 -- - # assert len(final_state.messages) > 0, "应该有消息记录" - - # -- 调试输出,可按需保留 -- - print("\n=== agent_results ===") - print(final_state.agent_results) - - if hasattr(final_state, "messages") and final_state.messages: - print("\n=== messages ===") - for msg in final_state.messages: - print(f"- {msg}") - - -# ============ 直接 python 执行 ============ -if __name__ == "__main__": - """ - 允许直接运行此文件进行快速测试 - """ - asyncio.run(run_{{ wf_name_snake }}_pipeline()) 
\ No newline at end of file diff --git a/dataflow_agent/templates/workflow.py.jinja b/dataflow_agent/templates/workflow.py.jinja deleted file mode 100644 index 1372f0f..0000000 --- a/dataflow_agent/templates/workflow.py.jinja +++ /dev/null @@ -1,257 +0,0 @@ -""" -{{ wf_name }} workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: {{ timestamp }} - -1. 在 **TOOLS** 区域定义需要暴露给 Prompt 的前置工具 -2. 在 **NODES** 区域实现异步节点函数 (await-able) -3. 在 **EDGES** 区域声明有向边 -4. 最后返回 builder.compile() 或 GenericGraphBuilder -""" - -from __future__ import annotations -import json -from dataclasses import Field -from pydantic import BaseModel -from dataflow_agent.states.xxState import xxState -# from dataflow_agent.state import xxState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import ( - create_agent, - create_simple_agent, - create_react_agent, - create_graph_agent, - create_vlm_agent, - SimpleConfig, - ReactConfig, - GraphConfig, - VLMConfig, - ExecutionMode, -) - -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from langchain.tools import tool -from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode, tools_condition - -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - -@register("{{ wf_name }}") -def create_{{ wf_name_snake }}_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf {{ wf_name }} - """ - builder = GenericGraphBuilder(state_model=xxState, - entry_point="{{ entry }}") # 自行修改入口 - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - # 例: - @builder.pre_tool("purpose", "step1") - def _purpose(state: xxState): - return "这里放入字符串 
/ 数值 / 列表 / 字典等供 prompt 使用" - - # 后置工具就是让agent选择的工具,可以定制多个; - @builder.post_tool("step2") - @tool - def _example_post_tool(date_str: str): - """ - 示例后置工具:根据日期获取天气信息 - - Args: - date_str: 日期字符串,格式为 "MM-DD"。 - 例如: "11-29" 表示 11 月 29 日。 - - Returns: - 一个描述该日期天气状况的中文字符串。 - 示例返回值: "明天天气晴朗,可以放心出行。" - """ - # 这里只是 demo,真实逻辑中请根据 date_str 查询实际天气 - return "明天天气晴朗,可以放心出行。" - - # ---------------------------------------------------------------------- - - # ============================================================== - # NODES - # ============================================================== - async def step1(state: xxState) -> xxState: - """ - 示例节点 1: 使用新的策略模式创建和执行 Agent - - 新版 Agent 创建方式推荐使用 `create_agent` 配合配置对象 (Config) - 或使用便捷函数 `create_simple_agent`, `create_react_agent` 等。 - - 执行模式说明: - - SimpleConfig: 简单模式,单次 LLM 调用 - - ReactConfig: ReAct 模式,带验证和重试的循环 - - GraphConfig: 图模式,用于执行带工具的子图 (LangGraph) - - VLMConfig: 视觉语言模型模式 - """ - - # ==================== 方式一:使用便捷函数 (推荐) ==================== - - # 示例 1: 创建一个简单的 Agent (SimpleConfig) - # agent = create_simple_agent( - # name="your_agent_name", # 替换为实际 agent 名称 - # model_name="gpt-4", # 模型名称,如 "gpt-4", "gpt-4-turbo" - # temperature=0.7, # 采样温度 (0.0-1.0),越高越随机 - # max_tokens=4096, # 最大生成token数 - # parser_type="json", # 解析器类型: "json", "xml", "text" - # ignore_history=True, # 是否忽略历史消息 - # ) - - # 示例 2: 创建一个 ReAct Agent,具备自动修正能力 (ReactConfig) - # agent = create_react_agent( - # name="your_agent_name", # 替换为实际 agent 名称 - # model_name="gpt-4-turbo", - # temperature=0.1, - # max_retries=3, # ReAct 模式下的最大重试次数 - # parser_type="json", # 定义输出解析器 - # # validators=[], # 可选:自定义验证器列表 - # ) - - # 示例 3: 创建一个图模式 Agent (GraphConfig) - # agent = create_graph_agent( - # name="your_agent_name", - # model_name="gpt-4", - # temperature=0.2, - # tool_mode="auto", # 工具调用模式: "auto", "none", "required" - # ) - - # 示例 4: 创建一个 VLM Agent 用于处理图像 (VLMConfig) - # agent = create_vlm_agent( - # name="your_agent_name", - # model_name="gpt-4-vision-preview", # 
视觉模型 - # temperature=0.1, - # vlm_mode="understanding", # 视觉模式: 'understanding', 'generation', 'edit' - # image_detail="high", # 图像细节: 'low', 'high', 'auto' - # max_image_size=(2048, 2048), # 最大图像尺寸 - # # additional_params={}, # 额外VLM参数 - # ) - - # ==================== 方式二:使用 create_agent 和配置对象 ==================== - - # 示例 5: 使用 SimpleConfig 创建简单模式 Agent - # config = SimpleConfig( - # model_name="gpt-4", - # temperature=0.7, - # max_tokens=4096, - # parser_type="json", - # ignore_history=True, - # ) - # agent = create_agent(name="your_agent_name", config=config) - - # 示例 6: 使用 ReactConfig 创建 ReAct 模式 Agent - # config = ReactConfig( - # model_name="gpt-4-turbo", - # temperature=0.1, - # max_retries=3, - # parser_type="json", - # # validators=[custom_validator], # 可选:自定义验证器 - # ) - # agent = create_agent(name="your_agent_name", config=config) - - # 示例 7: 使用 GraphConfig 创建图模式 Agent - # config = GraphConfig( - # model_name="gpt-4", - # temperature=0.2, - # tool_mode="auto", # 工具调用模式 - # ) - # agent = create_agent(name="your_agent_name", config=config) - - # 示例 8: 使用 VLMConfig 创建视觉模式 Agent - # config = VLMConfig( - # model_name="gpt-4-vision-preview", - # temperature=0.1, - # vlm_mode="understanding", - # image_detail="high", - # max_image_size=(2048, 2048), - # additional_params={"max_tokens": 4096}, - # ) - # agent = create_agent(name="your_agent_name", config=config) - - # ==================== 方式三:使用工具管理器 ==================== - - # 示例 9: 使用自定义工具管理器 - # from dataflow_agent.toolkits.tool_manager import ToolManager - # tool_manager = ToolManager() # 或使用 get_tool_manager() - # agent = create_simple_agent( - # name="your_agent_name", - # tool_manager=tool_manager, - # model_name="gpt-4", - # temperature=0.5, - # ) - - # ==================== 实际使用示例 ==================== - - # 实际使用:创建一个简单的代码审查 Agent - agent = create_simple_agent( - name="code_reviewer", # 替换为已注册的 agent 名称 - model_name="gpt-4-turbo", - temperature=0.1, - max_tokens=2048, - parser_type="json", - ) - - # 
或者创建一个带工具的文档分析 Agent - # agent = create_react_agent( - # name="document_analyzer", - # model_name="gpt-4", - # temperature=0.3, - # max_retries=2, - # parser_type="json", - # ) - - # -------------------------------------------------------------------- - - # 执行 agent (注意:新版 execute 不再需要 use_agent 参数) - # 你可以在 execute 中传入 kwargs 来覆盖 pre_tool_results - state = await agent.execute(state=state) - - # 可选:处理执行结果 - agent_result = state.agent_results.get(agent.role_name, {}) - log.info(f"Agent {agent.role_name} 执行结果: {agent_result}") - - return state - - async def step2(state: xxState) -> xxState: - """ - 示例节点 2: 处理agent执行结果 - - Args: - state: 主状态对象 - """ - # TODO: 替换为真正的业务逻辑 - state.agent_results["step2"] = {"msg": "hello step2"} - - # 示例:从 step1 的结果中提取数据 - # if "code_reviewer" in state.agent_results: - # review_result = state.agent_results["code_reviewer"] - # # 处理审查结果... - - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - nodes = { - "step1": step1, - "step2": step2, - '_end_': lambda state: state, # 终止节点 - } - - # ------------------------------------------------------------------ - # EDGES (从节点 A 指向节点 B) - # ------------------------------------------------------------------ - edges = [ - ("step1", "step2"), - ("step2", "_end_"), # 指向终止节点 - ] - - builder.add_nodes(nodes).add_edges(edges) - return builder diff --git a/dataflow_agent/toolkits/__init__.py b/dataflow_agent/toolkits/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/basetool/__init__.py b/dataflow_agent/toolkits/basetool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/basetool/file_tools.py b/dataflow_agent/toolkits/basetool/file_tools.py deleted file mode 100644 index 6db3efe..0000000 --- a/dataflow_agent/toolkits/basetool/file_tools.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import 
annotations -import asyncio -import importlib -import inspect -import sys -import os -import traceback -from pydantic import BaseModel -import httpx -import json -import uuid -from typing import List, Dict, Sequence, Any, Union, Optional, Iterable, Mapping, Set, Callable -from pathlib import Path - -from functools import lru_cache -import yaml -# from clickhouse_connect import get_client -import subprocess -from collections import defaultdict, deque -from dataflow.utils.storage import FileStorage -from dataflow_agent.logger import get_logger -logger = get_logger() -from dataflow_agent.storage.storage_service import SampleFileStorage -from dataflow_agent.state import DFState,DFRequest -import re - -MAX_JSONL_LINES = 50 -DATA_DIR = Path("./data/knowledgebase") # Local data storage directory - -def local_tool_for_sample( - state: DFRequest, - sample_size: int = 10, - use_file_sys: int = 1, - cache_type: str = "jsonl", - only_keys: bool = False, -) -> Dict[str, Any]: - from collections import Counter - """ - Sample, classify, and compute statistics on sample data. - - Args: - state: Request object containing file information - sample_size: Number of samples to retrieve. - use_file_sys: Whether to use file system storage (1) or not (0). - cache_type: Storage cache type ("jsonl" by default). - only_keys: If True, return only the keys found in samples - - Returns: - A dictionary with overall statistics and sample details. - """ - def judge_type(sample: Dict[str, Any]) -> str: - """ - Determine and return the type of a sample. - - Args: - sample: The sample to be judged. - - Returns: - The type of the sample as a string. 
- """ - if not isinstance(sample, dict): - return "Other" - if "conversations" in sample and isinstance(sample["conversations"], list): - ok = True - for msg in sample["conversations"]: - if not ( - (isinstance(msg, dict) and "role" in msg and "content" in msg) or - (isinstance(msg, dict) and "from" in msg and "value" in msg) - ): - ok = False - break - if ok: - return "SFT Multi-Round" - if "instruction" in sample and "output" in sample: - if isinstance(sample["instruction"], str) and isinstance(sample["output"], str): - if "input" not in sample or sample["input"] is None or isinstance(sample["input"], str): - return "SFT Single" - pt_keys = {"text", "content", "sentence"} - if len(sample) == 1: - (k, v), = sample.items() - if k in pt_keys and isinstance(v, str): - return "PT" - return "Other" - - # Storage selection - if use_file_sys: - from dataflow_agent.storage.storage_service import SampleFileStorage - - # 创建存储实例 - storage = SampleFileStorage( - first_entry_file_name=state.json_file, - cache_type=cache_type # 使用传入的cache_type参数 - ) - storage.step() - - logger.debug(f"------------Before Sampling--------------------") - - # 获取总数 - total = storage.count() - - # 使用新的rsample方法进行采样 - samples, actual_sample_size = storage.rsample( - mode="manual", - k=sample_size - ) - - logger.debug(f"------------After Sampling--------------------") - logger.debug(f"Requested: {sample_size}, Actual: {actual_sample_size}, Total: {total}") - - else: - # 如果不使用文件系统,返回空结果或者抛出异常 - logger.warning("Non-file system storage not implemented in new version") - samples = [] - total = 0 - - # 如果只需要keys,获取字段信息 - if only_keys: - if use_file_sys and storage: - # 使用新的get_fields方法 - key_set = set(storage.get_fields()) - # 如果需要从样本中获取更完整的keys - for sample in samples: - if isinstance(sample, dict): - key_set.update(sample.keys()) - return sorted(key_set) - else: - # 从样本中收集keys - key_set = set().union(*(s.keys() for s in samples if isinstance(s, dict))) - return sorted(key_set) - - # 分类样本并计算统计信息 - 
type_list = [judge_type(s) for s in samples] - counter = Counter(type_list) - - # 计算分布(基于实际样本数而不是总数) - sample_count = len(samples) - dist = { - t: {"count": c, "ratio": round(c / sample_count, 4) if sample_count > 0 else 0.0} - for t, c in counter.items() - } - - # 收集所有keys - key_set = set().union(*(s.keys() for s in samples if isinstance(s, dict))) - - stats = { - "total": total, - "sample_size": sample_count, - "stateed_size": sample_size, - "distribution": dist, - "samples": samples, - "available_keys": sorted(key_set) - } - - logger.debug(f"-------Data Statistics-------\n {stats}") - return stats - - -def local_tool_for_get_categories(): - """ - 返回 OPERATOR_REGISTRY 中实际注册的 operator 分类列表(如 agentic_rag, chemistry, ...)。 - """ - try: - from dataflow.utils.registry import OPERATOR_REGISTRY - if hasattr(OPERATOR_REGISTRY, '_init_loaders'): - OPERATOR_REGISTRY._init_loaders() - if hasattr(OPERATOR_REGISTRY, '_get_all'): - OPERATOR_REGISTRY._get_all() - categories = set() - for name, cls in OPERATOR_REGISTRY: - if hasattr(cls, '__module__'): - parts = cls.__module__.split('.') - if len(parts) >= 3 and parts[0] == 'dataflow' and parts[1] == 'operators': - categories.add(parts[2]) - return sorted(categories) - - except Exception as e: - return [] - -# ================================================================修改python文件的某行代码 - - -def change_pycode_lines( - file_path: Union[str, Path], - patches: Dict[int, str], - *, - encoding: str = "utf-8", - inherit_indent: bool = True, - make_backup: bool = True, - backup_suffix: str = ".bak", - write_back: bool = True, -) -> List[str]: - """ - 根据行号-文本映射修改 Python 文件,并可自动继承原行缩进。 - """ - path = Path(file_path).expanduser().resolve() - if not path.is_file(): - raise FileNotFoundError(path) - - # 读取原文件 - lines = path.read_text(encoding=encoding).splitlines(keepends=True) - - # 先备份 - if write_back and make_backup: - path.with_suffix(path.suffix + backup_suffix).write_text( - "".join(lines), encoding=encoding - ) - - max_line = 
len(lines) - invalid = [ln for ln in patches if ln < 1 or ln > max_line] - if invalid: - raise IndexError(f"行号越界 1-{max_line}: {invalid}") - - for ln, new_body in patches.items(): - old_line = lines[ln - 1] - - # 行尾换行符 - eol = old_line[len(old_line.rstrip("\r\n")) :] - - # 缩进 - indent = "" - if inherit_indent and not new_body.startswith((" ", "\t")): - indent = re.match(r"[ \t]*", old_line).group(0) - - newline = eol if eol else "\n" # 关键修复 - lines[ln - 1] = f"{indent}{new_body}{newline}" - - if write_back: - path.write_text("".join(lines), encoding=encoding) - - return lines - - -# =======================================================获取辅助源码 -from dataflow.utils.registry import OPERATOR_REGISTRY -def _extract_module_source(op_name: str) -> str: - """ - 根据 OPERATOR_REGISTRY 中登记的 `op_name` - 返回其**完整模块**源码字符串;提取失败时返回占位提示。 - - 1. 通过 `OPERATOR_REGISTRY.get()` 拉起 LazyLoader,拿到类对象; - 2. 借助 `cls.__module__` 取得模块名,再用 `importlib` / `inspect` - 提取源码; - 3. 若出现异常,记录日志并返回占位串,保证调用方逻辑不被打断。 - """ - logger = get_logger() - - try: - # ① 拉取并触发懒加载 - cls = OPERATOR_REGISTRY.get(op_name) - - # ② 确保模块已导入 - mod = importlib.import_module(cls.__module__) - - # ③ 提取源码 - return inspect.getsource(mod) - - except Exception as e: - logger.warning(f"无法提取 {op_name} 的源码: {e}") - logger.debug(traceback.format_exc()) - return "没有找到任务源代码,直接返回 other_info 即可;" - - -def get_otherinfo_code(op_names: List[str]) -> Dict[str, str]: - """ - 批量获取多个 operator 对应的源码字符串。 - - :param op_names: 由 operator 名称组成的列表 - :return: {op_name: source_code} - """ - return {name: _extract_module_source(name) for name in op_names} - - -# =============================高亮 -def flashy(msg: str, *, color: str = "yellow") -> str: - """ - 返回带 ANSI 颜色的字符串;调试场合用。 - 支持的 color: red / green / yellow / blue / magenta / cyan / white - """ - colors = { - "black": 30, "red": 31, "green": 32, "yellow": 33, - "blue": 34, "magenta": 35, "cyan": 36, "white": 37, - } - code = colors.get(color, 33) - return f"\033[{code}m{msg}\033[0m" - -if __name__ 
== "__main__": - # 简单测试 local_tool_for_sample - from dataflow_agent.state import DFRequest - state = DFRequest( - language="zh", - json_file=f"{DataFlowPath.get_dataflow_dir().parent}/dataflow/example/DataflowAgent/mq_test_data.jsonl" - ) - print(local_tool_for_sample(state,sample_size=2)) - # from dataflow.utils.registry import OPERATOR_REGISTRY - - # print("="*50) - # print("OPERATOR_REGISTRY 调试") - # print("="*50) - - # # 1. 查看初始状态 - # print(f"初始 _obj_map 长度: {len(OPERATOR_REGISTRY._obj_map)}") - # print(f"初始 loader_map: {OPERATOR_REGISTRY.loader_map}") - # print(f"初始 loader_map values: {list(OPERATOR_REGISTRY.loader_map.values())}") - - # # 2. 手动触发加载器初始化 - # print("\n手动触发 _init_loaders()...") - # try: - # OPERATOR_REGISTRY._init_loaders() - # print("✓ _init_loaders() 执行成功") - # print(f"加载后 loader_map values: {[type(v).__name__ for v in OPERATOR_REGISTRY.loader_map.values()]}") - # except Exception as e: - # print(f"✗ _init_loaders() 失败: {e}") - - # # 3. 手动触发加载所有操作符 - # print("\n手动触发 _get_all()...") - # try: - # OPERATOR_REGISTRY._get_all() - # print("✓ _get_all() 执行成功") - # print(f"_get_all() 后 _obj_map 长度: {len(OPERATOR_REGISTRY._obj_map)}") - # except Exception as e: - # print(f"✗ _get_all() 失败: {e}") - # # 如果 _get_all() 失败,尝试手动调用每个loader的 _import_all() - # print("尝试手动加载每个模块...") - # for module_name, loader in OPERATOR_REGISTRY.loader_map.items(): - # if loader is not None: - # try: - # print(f" 加载 {module_name}...") - # loader._import_all() - # print(f" ✓ {module_name} 加载成功") - # except Exception as me: - # print(f" ✗ {module_name} 加载失败: {me}") - # print(f"手动加载后 _obj_map 长度: {len(OPERATOR_REGISTRY._obj_map)}") - - # # 4. 
创建 _NAME2CLS 并显示结果 - # _NAME2CLS = {name: cls for name, cls in OPERATOR_REGISTRY} - - # print(f"\n最终结果:") - # print(f"_NAME2CLS 长度: {len(_NAME2CLS)}") - # print(f"_NAME2CLS keys (前20个): {list(_NAME2CLS.keys())[:20]}") - - # if _NAME2CLS: - # print(f"\n前10个操作符详情:") - # for i, (name, cls) in enumerate(_NAME2CLS.items()): - # if i >= 10: - # break - # print(f" {i+1}. {name}: {cls}") - # if hasattr(cls, '__module__'): - # print(f" 模块: {cls.__module__}") - # else: - # print("⚠️ _NAME2CLS 仍然为空!") - - # # 如果还是空的,尝试直接调用 get() 方法来触发单个操作符的加载 - # print("\n尝试使用 get() 方法触发加载...") - # # 随便试一个可能的操作符名称 - # test_names = ["TextOperator", "FileReader", "DataProcessor", "BaseOperator"] - # for test_name in test_names: - # try: - # cls = OPERATOR_REGISTRY.get(test_name) - # print(f"✓ 成功获取 {test_name}: {cls}") - # break - # except KeyError: - # print(f"✗ {test_name} 不存在") - # except Exception as e: - # print(f"✗ 获取 {test_name} 时出错: {e}") - - # # 再次检查 - # _NAME2CLS = {name: cls for name, cls in OPERATOR_REGISTRY} - # print(f"使用 get() 后 _NAME2CLS 长度: {len(_NAME2CLS)}") - - # # 5. 
分析分类 - # if _NAME2CLS: - # print(f"\n按模块分类:") - # module_counts = {} - # for name, cls in _NAME2CLS.items(): - # if hasattr(cls, '__module__'): - # module_parts = cls.__module__.split('.') - # if len(module_parts) > 2 and module_parts[0] == 'dataflow' and module_parts[1] == 'operators': - # category = module_parts[2] # 取 dataflow.operators.xxx 中的 xxx - # else: - # category = cls.__module__ - - # module_counts[category] = module_counts.get(category, 0) + 1 - - # for category, count in sorted(module_counts.items()): - # print(f" {category}: {count} 个操作符") - - # print("="*50) \ No newline at end of file diff --git a/dataflow_agent/toolkits/basetool/llm_client.py b/dataflow_agent/toolkits/basetool/llm_client.py deleted file mode 100644 index 1f80974..0000000 --- a/dataflow_agent/toolkits/basetool/llm_client.py +++ /dev/null @@ -1,52 +0,0 @@ -import httpx -from typing import List, Dict, Any, Optional - -class LLMClient: - def __init__( - self, - base_url: str, - api_key: str, - model: str, - *, - timeout: int = 60, - org_id: Optional[str] = None, - ) -> None: - self.base_url = base_url.rstrip("/") - self.api_key = api_key - self.model = model - self.timeout = timeout - self.org_id = org_id - self._sync_client = httpx.Client(timeout=self.timeout, headers=self._headers()) - self._async_client = httpx.AsyncClient(timeout=self.timeout, headers=self._headers()) - - def _headers(self) -> Dict[str, str]: - h = {"Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json"} - if self.org_id: - h["OpenAI-Organization"] = self.org_id - return h - - def _prepare_payload(self, messages: List[Dict[str, str]], **extra: Any) -> Dict[str, Any]: - return { - "model": self.model, - "messages": messages, - **extra, - } - - def _parse_response(self, data: Dict[str, Any]) -> str: - try: - return data["choices"][0]["message"]["content"] - except (KeyError, IndexError, TypeError): - raise RuntimeError(f"Bad schema from API: {data}") from None - - def chat(self, messages: 
List[Dict[str, str]], **extra) -> str: - payload = self._prepare_payload(messages, **extra) - r = self._sync_client.post(self.base_url, json=payload) - r.raise_for_status() - return self._parse_response(r.json()) - - async def async_chat(self, messages: List[Dict[str, str]], **extra: Any) -> str: - payload = self._prepare_payload(messages, **extra) - r = await self._async_client.post(self.base_url, json=payload) - r.raise_for_status() - return self._parse_response(r.json()) \ No newline at end of file diff --git a/dataflow_agent/toolkits/dockertool/Makefile b/dataflow_agent/toolkits/dockertool/Makefile deleted file mode 100644 index e608f6b..0000000 --- a/dataflow_agent/toolkits/dockertool/Makefile +++ /dev/null @@ -1,29 +0,0 @@ -IMAGE ?= myorg/dataflow-py:3.11 - -.PHONY: build image run-thirdparty run-stdlib clean - -build: - python -m dataflow.dataflowagent.toolkits.dockertool.build_docker -i $(IMAGE) - -image: build - docker images | rg $(IMAGE) || true - -run-thirdparty: - python -m dataflow.dataflowagent.toolkits.dockertool.mini_docker \ - --file dataflow/dataflowagent/toolkits/dockertool/examples/hello_thirdparty.py \ - --requirements dataflow/dataflowagent/toolkits/dockertool/examples/requirements_thirdparty.txt \ - --image $(IMAGE) \ - --save dataflow/dataflowagent/toolkits/dockertool/artifacts/hello_thirdparty.tar \ - --timeout 180 - -run-stdlib: - python -m dataflow.dataflowagent.toolkits.dockertool.mini_docker \ - --file dataflow/dataflowagent/toolkits/dockertool/examples/hello_stdlib.py \ - --image $(IMAGE) \ - --save dataflow/dataflowagent/toolkits/dockertool/artifacts/hello_stdlib.tar \ - --timeout 120 - -clean: - rm -f dataflow/dataflowagent/toolkits/dockertool/artifacts/*.tar - rm -f dataflow/dataflowagent/toolkits/dockertool/artifacts/run_logs/* - diff --git a/dataflow_agent/toolkits/dockertool/README.md b/dataflow_agent/toolkits/dockertool/README.md deleted file mode 100644 index 03c8422..0000000 --- a/dataflow_agent/toolkits/dockertool/README.md 
+++ /dev/null @@ -1,33 +0,0 @@ -# Mini Docker Tool (DataFlow) - -该工具用于在最小隔离的 Docker 容器中运行指定的 Python 文件(可选安装 requirements),并在脚本成功完成后保存当前容器环境为镜像 tar,供后续复用。 - -## 快速开始 - -- 构建基础镜像(默认 `myorg/dataflow-py:3.11`): - - `python -m dataflow.dataflowagent.toolkits.dockertool.build_docker -i myorg/dataflow-py:3.11` - -- 运行示例并保存环境: - - 第三方依赖示例: - - `python -m dataflow.dataflowagent.toolkits.dockertool.mini_docker --file dataflow/dataflowagent/toolkits/dockertool/examples/hello_thirdparty.py --requirements dataflow/dataflowagent/toolkits/dockertool/examples/requirements_thirdparty.txt --image myorg/dataflow-py:3.11 --save dataflow/dataflowagent/toolkits/dockertool/artifacts/hello_thirdparty.tar --timeout 180` - - 仅标准库示例: - - `python -m dataflow.dataflowagent.toolkits.dockertool.mini_docker --file dataflow/dataflowagent/toolkits/dockertool/examples/hello_stdlib.py --image myorg/dataflow-py:3.11 --save dataflow/dataflowagent/toolkits/dockertool/artifacts/hello_stdlib.tar --timeout 120` - -运行成功后,会输出 JSON,包含:`success`、`return_code`、`stdout`、`stderr`、`image`、`tag`、`tar_path`、`log_path`。 - -## 设计要点 - -- 为避免依赖宿主机挂载的脚本在提交镜像后丢失,运行前会把脚本和 `requirements.txt` 复制到容器内部 `/app_job`。 -- 容器以“保活模式”启动,随后通过 `exec_run` 执行 `pip install`(如传入 `requirements`)与 `python -u /app_job/script.py`,以稳定获取退出码和输出。 -- 失败时不提交镜像;成功时 `commit` 并保存为 tar 到 `artifacts/`。 -- 本地会保存运行日志到 `artifacts/run_logs/`,用于排查问题。 - -## 目录说明 - -- `mini_docker.py`:核心工具入口,提供 CLI。 -- `build_docker.py`:构建基础镜像脚本(生成 `docker/Dockerfile` 与 `docker/requirements.txt`)。 -- `docker/`:镜像构建所需文件(自动生成)。 -- `examples/`:示例脚本与示例依赖清单。 -- `artifacts/`:输出的镜像 tar 与运行日志(git 忽略 tar 和日志内容)。 - - diff --git a/dataflow_agent/toolkits/dockertool/__init__.py b/dataflow_agent/toolkits/dockertool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/dockertool/build_docker.py b/dataflow_agent/toolkits/dockertool/build_docker.py deleted file mode 100644 index 2e390f2..0000000 --- a/dataflow_agent/toolkits/dockertool/build_docker.py 
+++ /dev/null @@ -1,77 +0,0 @@ -""" -Python 版镜像构建脚本: -1. 写 requirements.txt / Dockerfile -2. docker build -3. (可选) docker push -""" - -import argparse -import subprocess -from pathlib import Path -import textwrap -import sys - -DEFAULT_REQS = """\ -pandas==2.2.2 -numpy==1.26.4 -sqlalchemy>=2.0,<3.0 - - - - -""" - -DOCKERFILE_TMPL = """\ -FROM python:{py_version}-slim AS runtime -RUN apt-get update && apt-get install -y --no-install-recommends \\ - build-essential gcc \\ - && rm -rf /var/lib/apt/lists/* -COPY requirements.txt /tmp/requirements.txt -RUN python -m pip install --no-cache-dir -r /tmp/requirements.txt -RUN useradd -ms /bin/bash appuser -USER appuser -WORKDIR /app -ENTRYPOINT ["python"] -""" - -def run(cmd: list[str], **kw): - print(">>", " ".join(cmd)) - subprocess.run(cmd, check=True, **kw) - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--image", default="myorg/dataflow-py:3.11", - help="镜像名:TAG") - parser.add_argument("--py", "--python", dest="py_version", - default="3.11", help="Python 主版本") - parser.add_argument("--push", action="store_true", - help="构建后 docker push") - opts = parser.parse_args() - - root = Path(__file__).resolve().parent - docker_dir = root / "docker" - docker_dir.mkdir(exist_ok=True) - - # 1. requirements.txt - (docker_dir / "requirements.txt").write_text(DEFAULT_REQS, encoding="utf-8") - print("✓ requirements.txt written") - - # 2. Dockerfile - dockerfile_str = DOCKERFILE_TMPL.format(py_version=opts.py_version) - (docker_dir / "Dockerfile").write_text(dockerfile_str, encoding="utf-8") - print("✓ Dockerfile written") - - # 3. build - run(["docker", "build", "-t", opts.image, str(docker_dir)]) - - # 4. 
push - if opts.push: - run(["docker", "push", opts.image]) - - print(f"\n 完成!请把 DOCKER_IMAGE 改为 '{opts.image}'\n") - -if __name__ == "__main__": - try: - main() - except subprocess.CalledProcessError as e: - sys.exit(e.returncode) \ No newline at end of file diff --git a/dataflow_agent/toolkits/dockertool/docker/Dockerfile b/dataflow_agent/toolkits/dockertool/docker/Dockerfile deleted file mode 100644 index 92ee6b0..0000000 --- a/dataflow_agent/toolkits/dockertool/docker/Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -FROM python:3.11-slim AS runtime -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential gcc \ - && rm -rf /var/lib/apt/lists/* -COPY requirements.txt /tmp/requirements.txt -RUN python -m pip install --no-cache-dir -r /tmp/requirements.txt -RUN useradd -ms /bin/bash appuser -USER appuser -WORKDIR /app -ENTRYPOINT ["python"] diff --git a/dataflow_agent/toolkits/dockertool/docker/requirements.txt b/dataflow_agent/toolkits/dockertool/docker/requirements.txt deleted file mode 100644 index 0ddea16..0000000 --- a/dataflow_agent/toolkits/dockertool/docker/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -pandas==2.2.2 -numpy==1.26.4 -sqlalchemy>=2.0,<3.0 - - - - diff --git a/dataflow_agent/toolkits/dockertool/examples/hello_stdlib.py b/dataflow_agent/toolkits/dockertool/examples/hello_stdlib.py deleted file mode 100644 index e25a0e5..0000000 --- a/dataflow_agent/toolkits/dockertool/examples/hello_stdlib.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -一个不依赖第三方库的简单示例脚本,用于测试 minidocker 运行与保存镜像。 - -运行后应打印几行信息并正常退出(退出码 0)。 -""" - -import sys -import platform -from datetime import datetime - - -def main(): - print("[hello_stdlib] 启动时间:", datetime.now().isoformat()) - print("[hello_stdlib] Python:", sys.version.replace("\n", " ")) - print("[hello_stdlib] 平台:", platform.platform()) - - # 做一个简单计算 - total = sum(i * i for i in range(10)) - print("[hello_stdlib] 计算结果 sum(i*i, i=0..9):", total) - - print("[hello_stdlib] 脚本运行完成,准备退出。") - - -if __name__ == 
"__main__": - main() - diff --git a/dataflow_agent/toolkits/dockertool/examples/hello_thirdparty.py b/dataflow_agent/toolkits/dockertool/examples/hello_thirdparty.py deleted file mode 100644 index afcbe7f..0000000 --- a/dataflow_agent/toolkits/dockertool/examples/hello_thirdparty.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -一个依赖第三方库的示例脚本:使用 pandas 和 numpy。 -运行后应打印 pandas/numpy 版本、简单数据处理结果。 -""" - -import sys -import platform -from datetime import datetime - -import numpy as np -import pandas as pd - - -def main(): - print("[hello_thirdparty] 启动时间:", datetime.now().isoformat()) - print("[hello_thirdparty] Python:", sys.version.replace("\n", " ")) - print("[hello_thirdparty] 平台:", platform.platform()) - print("[hello_thirdparty] numpy:", np.__version__) - print("[hello_thirdparty] pandas:", pd.__version__) - - # 构造一个简单的 DataFrame 并做一次计算 - df = pd.DataFrame({"a": np.arange(5), "b": np.arange(5) ** 2}) - df["c"] = df["a"] + df["b"] - print("[hello_thirdparty] DataFrame head:\n", df.head().to_string(index=False)) - - print("[hello_thirdparty] 脚本运行完成,准备退出。") - abc - -if __name__ == "__main__": - main() - diff --git a/dataflow_agent/toolkits/dockertool/examples/requirements_thirdparty.txt b/dataflow_agent/toolkits/dockertool/examples/requirements_thirdparty.txt deleted file mode 100644 index d1c49cc..0000000 --- a/dataflow_agent/toolkits/dockertool/examples/requirements_thirdparty.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.26.4 -pandas==2.2.2 diff --git a/dataflow_agent/toolkits/dockertool/mini_docker.py b/dataflow_agent/toolkits/dockertool/mini_docker.py deleted file mode 100644 index 9409379..0000000 --- a/dataflow_agent/toolkits/dockertool/mini_docker.py +++ /dev/null @@ -1,415 +0,0 @@ -from __future__ import annotations -import asyncio -import argparse -import io -import json -import os -import tarfile -import sys -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Optional - -import docker # pip install docker -from docker.errors 
import DockerException, NotFound, APIError, ContainerError - -# 使用 build_docker.py 生成的自定义镜像 -DOCKER_IMAGE = "myorg/dataflow-py:3.11" -RUN_TIMEOUT = 120 # s - - -def _run_in_container_sync( - file_path: Path, - image: str = DOCKER_IMAGE, - timeout: int = RUN_TIMEOUT, -) -> Dict[str, Any]: - """ - 旧版:只读挂载、禁网运行(保留兼容)。实际请使用 run_file_in_minidocker。 - """ - client = docker.from_env() - - try: - client.images.get(image) - except NotFound: - client.close() - return { - "success": False, - "return_code": -1, - "stdout": "", - "stderr": ( - f"image not found: {image}. 请先在宿主机预置镜像(例如 docker load)," - f"或用 build_docker.py 构建并保存/加载。" - ), - "file_path": str(file_path), - } - - container = None - try: - container = client.containers.run( - image=image, - command=["python", "/app/script.py"], - volumes={str(file_path): {"bind": "/app/script.py", "mode": "ro"}}, - detach=True, - stdout=True, - stderr=True, - network_disabled=True, - read_only=True, - mem_limit="512m", - pids_limit=128, - cpu_quota=100000, - ) - - exit_status = container.wait(timeout=timeout) - code = exit_status.get("StatusCode", 137) - result = { - "success": code == 0, - "return_code": code, - "stdout": container.logs(stdout=True, stderr=False).decode(), - "stderr": container.logs(stdout=False, stderr=True).decode(), - "file_path": str(file_path), - } - except DockerException as e: - result = { - "success": False, - "return_code": -1, - "stdout": "", - "stderr": f"Docker error: {e}", - "file_path": str(file_path), - } - finally: - if container: - try: - container.remove(force=True) - except Exception: - pass - client.close() - - return result - - -def _make_tar_bytes(files: Dict[str, bytes]) -> bytes: - """ - 将内存文件打包为 tar 字节流,用于 put_archive。 - """ - buf = io.BytesIO() - with tarfile.open(fileobj=buf, mode="w") as tf: - for rel_path, data in files.items(): - ti = tarfile.TarInfo(name=rel_path) - ti.size = len(data) - ti.mtime = int(datetime.now().timestamp()) - tf.addfile(ti, io.BytesIO(data)) - buf.seek(0) - 
return buf.read() - - -def run_file_in_minidocker( - file_path: str | Path, - image: str = DOCKER_IMAGE, - workdir: Optional[str] = None, - requirements: Optional[str] = None, - save_tar_path: Optional[str] = None, - enable_network: bool = True, - timeout: int = RUN_TIMEOUT, -) -> Dict[str, Any]: - """ - 将 Python 文件复制进容器,联网安装依赖(可选),运行并在成功后提交镜像并保存为 tar。 - 仅保存容器镜像 tar,不额外保存脚本输出文件。 - """ - src = Path(file_path).expanduser().resolve() - if not src.exists(): - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"file not found: {src}"} - - req_path: Optional[Path] = None - if requirements: - rp = Path(requirements).expanduser().resolve() - if not rp.exists(): - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"requirements not found: {rp}"} - req_path = rp - - client = docker.from_env() - - try: - client.images.get(image) - except NotFound: - client.close() - return { - "success": False, - "return_code": -1, - "stdout": "", - "stderr": ( - f"image not found: {image}. 
请先在宿主机预置镜像(例如 docker load)," - f"或用 build_docker.py 构建并保存/加载。" - ), - } - - # 可选:挂载宿主机工作目录到 /workspace(便于用户脚本访问外部数据)。脚本本身与 requirements 会复制到容器内部。 - volumes = {} - if workdir: - volumes[os.path.abspath(workdir)] = {"bind": "/workspace", "mode": "rw"} - - # 通过覆盖 entrypoint 为 sh -lc,一次性执行:pip 安装(可选)+ 运行脚本。 - # 简化为直接依赖容器 stdout/stderr(docker 日志)来获取输出,并在宿主机保存一份副本。 - run_ts = datetime.now().strftime("%Y%m%d-%H%M%S") - out_filename = f"{Path(file_path).stem}_{run_ts}.log" - cmd_parts = [] - if req_path: - cmd_parts.append("python -m pip install --user -r /app/requirements.txt") - # 直接运行脚本,输出由 docker 日志采集 - cmd_parts.append("python -u /app/script.py") - cmd_str = " && ".join(cmd_parts) - - # 挂载仅保留用户指定的工作目录,其它文件通过 put_archive 复制进容器,确保 commit 后可用 - mount_volumes = {} - if volumes: - mount_volumes.update(volumes) - - artifacts_dir = Path(__file__).resolve().parent / "artifacts" - run_logs_dir = artifacts_dir / "run_logs" - run_logs_dir.mkdir(parents=True, exist_ok=True) - - # 改为 run(detach) + exec_run,避免 shell 行为影响退出码和输出 - # 尝试多种保活策略,确保容器处于 running - container = None - start_error: Optional[Exception] = None - strategies = [ - {"entrypoint": ["tail", "-f", "/dev/null"], "command": None}, - {"entrypoint": ["/bin/sh", "-c"], "command": "sleep infinity"}, - {"entrypoint": ["python"], "command": ["-c", "import time; time.sleep(10**9)"]}, - ] - for strat in strategies: - try: - container = client.containers.run( - image=image, - entrypoint=strat["entrypoint"], - command=strat["command"], - detach=True, - volumes=mount_volumes, - network_disabled=not enable_network, - mem_limit="2g", - pids_limit=512, - cpu_quota=200000, - working_dir="/app", - environment={"PYTHONUNBUFFERED": "1"}, - ) - start_error = None - break - except (APIError, ContainerError, DockerException) as e: - start_error = e - continue - if container is None: - client.close() - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"container start error: {start_error}"} - - # 等待容器进入 running 状态,避免后续 
exec_run 409 错误 - try: - waited = 0.0 - max_wait = float(timeout) - while waited < max_wait: - container.reload() - status = getattr(container, "status", None) - if status == "running": - break - if status in {"exited", "dead"}: - logs = "" - try: - logs = container.logs().decode() - except Exception: - pass - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"container not running, status={status}. logs: {logs}"} - await_time = 0.2 - import time as _t - _t.sleep(await_time) - waited += await_time - # 最后再检查一次 - container.reload() - if getattr(container, "status", None) != "running": - client.close() - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"container not running after wait, status={getattr(container, 'status', None)}"} - except Exception as e: - logs = "" - try: - logs = container.logs().decode() - except Exception: - pass - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"container wait error: {e}. 
logs: {logs}"} - - combined_logs: list[str] = [] - # 创建工作目录并复制脚本/依赖到容器内部,确保 commit 后仍可用 - job_dir = "/app_job" - try: - _code, _out = container.exec_run(["/bin/sh", "-c", f"mkdir -p {job_dir}"]) - # 打包文件为 tar 并上传 - files: Dict[str, bytes] = {} - files["script.py"] = src.read_bytes() - if req_path: - files["requirements.txt"] = req_path.read_bytes() - tar_bytes = _make_tar_bytes(files) - container.put_archive(job_dir, tar_bytes) - except Exception as e: - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": "", "stderr": f"put files error: {e}"} - # 先安装依赖(如有) - if req_path: - try: - pip_exit, pip_out = container.exec_run( - ["python", "-m", "pip", "install", "--user", "-r", f"{job_dir}/requirements.txt"], - stream=False, - ) - try: - combined_logs.append(pip_out.decode()) - except Exception: - combined_logs.append(pip_out.decode(errors="ignore")) - if pip_exit not in (0, None): - # 依赖安装失败 - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": pip_exit or -1, "stdout": "".join(combined_logs), "stderr": "pip install failed"} - except Exception as e: - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": "".join(combined_logs), "stderr": f"pip exec error: {e}"} - - # 运行脚本 - try: - code, run_out = container.exec_run(["python", "-u", f"{job_dir}/script.py"], stream=False) - try: - combined_logs.append(run_out.decode()) - except Exception: - combined_logs.append(run_out.decode(errors="ignore")) - if code is None: - code = 137 - except Exception as e: - out = "".join(combined_logs) - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": out, "stderr": f"run exec error: {e}"} - - stdout_s = "".join(combined_logs) - stderr_s = "" # exec_run 已合并输出 - if code != 
0 or ("Traceback (most recent call last)" in stdout_s): - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": code, "stdout": stdout_s, "stderr": stderr_s} - - # 成功提交镜像并保存 tar - repo = f"minidocker/{src.stem}" - tag = datetime.now().strftime("%Y%m%d-%H%M%S") - try: - image_obj = container.commit(repository=repo, tag=tag) - except DockerException as e: - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": stdout_s, "stderr": f"commit error: {e}"} - - artifacts_dir = Path(__file__).resolve().parent / "artifacts" - artifacts_dir.mkdir(exist_ok=True) - # 将运行日志保存到本地,便于排查问题 - run_logs_dir = artifacts_dir / "run_logs" - run_logs_dir.mkdir(parents=True, exist_ok=True) - log_path = run_logs_dir / out_filename - try: - with open(log_path, "w", encoding="utf-8") as f: - f.write(stdout_s) - if stderr_s: - f.write("\n[stderr]\n") - f.write(stderr_s) - except Exception: - # 如果保存日志失败,不影响后续保存 tar - pass - - tar_path = Path(save_tar_path).expanduser().resolve() if save_tar_path else artifacts_dir / f"{repo.replace('/', '_')}_{tag}.tar" - try: - with open(tar_path, "wb") as f: - for chunk in image_obj.save(named=True): - f.write(chunk) - except Exception as e: - try: - container.remove(force=True) - except Exception: - pass - client.close() - return {"success": False, "return_code": -1, "stdout": stdout_s, "stderr": f"save tar error: {e}"} - - try: - container.remove(force=True) - except Exception: - pass - client.close() - return { - "success": True, - "return_code": 0, - "stdout": stdout_s, - "stderr": stderr_s, - "image": repo, - "tag": tag, - "tar_path": str(tar_path), - "log_path": str(log_path), - } - - -async def _run_py_in_docker( - file_path: Path, - image: str = DOCKER_IMAGE, - timeout: int = RUN_TIMEOUT, -) -> Dict[str, Any]: - """ - 异步封装(旧版),建议换用 run_file_in_minidocker。 - """ - return await 
asyncio.to_thread(_run_in_container_sync, file_path, image, timeout) - - -def _cli_main(argv: list[str] | None = None) -> int: - parser = argparse.ArgumentParser(description="在 minidocker 中运行 Python 文件并保存镜像 tar") - parser.add_argument("--file", required=True, help="要运行的 Python 脚本路径") - parser.add_argument("--image", default=DOCKER_IMAGE, help="基础镜像名:TAG") - parser.add_argument("--requirements", default=None, help="requirements.txt 路径(可选)") - parser.add_argument("--save", dest="save_tar", default=None, help="保存 tar 的目标路径(可选)") - parser.add_argument("--workdir", default=None, help="宿主机工作目录,映射到 /workspace(可选)") - parser.add_argument("--timeout", type=int, default=RUN_TIMEOUT, help="运行超时(秒)") - parser.add_argument("--enable-network", action="store_true", default=True, help="启用容器网络(默认开启)") - args = parser.parse_args(argv) - - res = run_file_in_minidocker( - file_path=args.file, - image=args.image, - workdir=args.workdir, - requirements=args.requirements, - save_tar_path=args.save_tar, - enable_network=args.enable_network, - timeout=args.timeout, - ) - print(json.dumps(res, ensure_ascii=False, indent=2)) - return 0 if res.get("success") else 1 - - -if __name__ == "__main__": - sys.exit(_cli_main()) diff --git a/dataflow_agent/toolkits/drawio_tools.py b/dataflow_agent/toolkits/drawio_tools.py deleted file mode 100644 index 00c3ae2..0000000 --- a/dataflow_agent/toolkits/drawio_tools.py +++ /dev/null @@ -1,586 +0,0 @@ -""" -Draw.io XML 工具函数 -提供 XML 包装、提取、验证和编辑功能 -""" -import os -import re -import shutil -import subprocess -import tempfile -import xml.etree.ElementTree as ET -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple, Iterable - - -# draw.io XML 模板 -DRAWIO_WRAPPER_TEMPLATE = ''' - - - - - - -{cells} - - - -''' - - -def wrap_xml(cells_xml: str, modified: str = "", page_width: int | float = 850, page_height: int | float = 1100) -> str: - """ - 将 mxCell 元素包装为完整的 draw.io XML - - Args: - cells_xml: 仅包含 mxCell 元素的 XML 字符串 - modified: 修改时间戳 - - 
Returns: - 完整的 draw.io XML 文件内容 - """ - from datetime import datetime - if not modified: - modified = datetime.now().isoformat() - - # 确保每行有适当的缩进 - lines = cells_xml.strip().split('\n') - indented_lines = [' ' + line.strip() for line in lines if line.strip()] - indented_cells = '\n'.join(indented_lines) - - return DRAWIO_WRAPPER_TEMPLATE.format( - modified=modified, - cells=indented_cells, - page_width=int(round(page_width)), - page_height=int(round(page_height)), - ) - - -def extract_cells(full_xml: str) -> str: - """ - 从完整的 draw.io XML 中提取 mxCell 元素 - - Args: - full_xml: 完整的 draw.io XML - - Returns: - 仅包含 mxCell 元素的 XML 字符串(不含 id="0" 和 id="1") - """ - try: - root = ET.fromstring(full_xml) - cells = [] - - # 查找所有 mxCell 元素 - for cell in root.iter('mxCell'): - cell_id = cell.get('id', '') - # 跳过根单元格 - if cell_id in ('0', '1'): - continue - cells.append(ET.tostring(cell, encoding='unicode')) - - return '\n'.join(cells) - except ET.ParseError as e: - # 如果解析失败,尝试用正则提取 - pattern = r']*id="(?!0|1")[^"]*"[^>]*>.*?|]*id="(?!0|1")[^"]*"[^/]*/>' - matches = re.findall(pattern, full_xml, re.DOTALL) - return '\n'.join(matches) - - -def validate_xml(cells_xml: str) -> Tuple[bool, List[str]]: - """ - 验证 mxCell XML 的结构 - - Args: - cells_xml: mxCell 元素的 XML 字符串 - - Returns: - (is_valid, errors) 元组 - """ - errors = [] - - # 检查是否包含禁止的包装标签 - forbidden_tags = ['', ']*id=["\']0["\']', cells_xml): - errors.append("包含禁止的根单元格 id='0'") - if re.search(r']*id=["\']1["\']', cells_xml): - errors.append("包含禁止的根单元格 id='1'") - - # 尝试解析 XML - try: - # 包装后解析以验证结构 - wrapped = f"{cells_xml}" - root = ET.fromstring(wrapped) - - # 检查 ID 唯一性 - ids = set() - for cell in root.findall('.//mxCell'): - cell_id = cell.get('id') - if cell_id in ids: - errors.append(f"重复的 ID: {cell_id}") - ids.add(cell_id) - - # 检查必要属性 - if not cell.get('parent'): - errors.append(f"单元格 {cell_id} 缺少 parent 属性") - - except ET.ParseError as e: - errors.append(f"XML 解析错误: {str(e)}") - - return len(errors) == 0, errors - - -def 
sanitize_cells_xml(cells_xml: str) -> str: - """ - Clean mxCell XML output to reduce common rendering failures. - """ - if not cells_xml: - return "" - - cleaned = cells_xml.strip() - - # Remove markdown code fences if present. - if cleaned.startswith("```"): - cleaned = re.sub(r"^```[a-zA-Z]*\n?", "", cleaned) - cleaned = re.sub(r"```$", "", cleaned) - cleaned = cleaned.strip() - - # Strip XML comments. - cleaned = re.sub(r"", "", cleaned, flags=re.DOTALL).strip() - - # Escape bare ampersands (keep valid entities). - cleaned = re.sub( - r"&(?!amp;|lt;|gt;|quot;|apos;|#\d+;|#x[0-9A-Fa-f]+;)", - "&", - cleaned, - ) - - return cleaned.strip() - - -def _iter_vertices(root: ET.Element) -> Iterable[ET.Element]: - for cell in root.findall('.//mxCell'): - if cell.get('vertex') == '1': - yield cell - - -def _get_geometry(cell: ET.Element) -> Tuple[ET.Element, float, float, float, float]: - geom = cell.find('mxGeometry') - if geom is None: - geom = ET.SubElement(cell, 'mxGeometry') - geom.set('as', 'geometry') - x = float(geom.get('x', '0') or 0) - y = float(geom.get('y', '0') or 0) - w = float(geom.get('width', '120') or 120) - h = float(geom.get('height', '60') or 60) - return geom, x, y, w, h - - -def _overlaps(a: Tuple[float, float, float, float], b: Tuple[float, float, float, float], padding: float) -> bool: - ax, ay, aw, ah = a - bx, by, bw, bh = b - return not ( - ax + aw + padding <= bx or - bx + bw + padding <= ax or - ay + ah + padding <= by or - by + bh + padding <= ay - ) - - -def _style_has(cell: ET.Element, token: str) -> bool: - style = (cell.get('style') or "") - return token in style - - -def _set_geometry(geom: ET.Element, x: float, y: float, w: float, h: float) -> None: - geom.set('x', f"{x:.0f}") - geom.set('y', f"{y:.0f}") - geom.set('width', f"{w:.0f}") - geom.set('height', f"{h:.0f}") - - -def _center(cell: ET.Element) -> Tuple[float, float]: - _, x, y, w, h = _get_geometry(cell) - return x + w / 2.0, y + h / 2.0 - - -def _grid_layout( - cells: 
List[ET.Element], - start_x: float, - start_y: float, - max_width: float, - gap_x: float, - gap_y: float, -) -> Tuple[float, float]: - if not cells: - return start_x, start_y - cols = max(1, int(max_width // (gap_x + 1))) - col = 0 - row = 0 - max_x = start_x - max_y = start_y - for cell in cells: - geom, _, _, w, h = _get_geometry(cell) - x = start_x + col * gap_x - y = start_y + row * gap_y - _set_geometry(geom, x, y, w, h) - max_x = max(max_x, x + w) - max_y = max(max_y, y + h) - col += 1 - if col >= cols: - col = 0 - row += 1 - return max_x, max_y - - -def _layout_children_in_container( - container: ET.Element, - children: List[ET.Element], - padding: float, - gap: float -) -> None: - if not children: - return - geom, x, y, w, h = _get_geometry(container) - inner_x = x + padding - inner_y = y + padding - max_width = max(120.0, w - padding * 2) - max_x, max_y = _grid_layout( - children, - inner_x, - inner_y, - max_width=max_width, - gap_x=max(140.0, gap), - gap_y=max(90.0, gap), - ) - needed_w = max(w, (max_x - x) + padding) - needed_h = max(h, (max_y - y) + padding) - _set_geometry(geom, x, y, needed_w, needed_h) - - -def _layout_top_level( - diagram_type: str, - vertices: List[ET.Element], - margin: float, - canvas_width: float, - canvas_height: float, - gap: float -) -> None: - if not vertices: - return - diagram_type = (diagram_type or "auto").lower() - if diagram_type == "flowchart": - x = margin - y = margin - for cell in vertices: - geom, _, _, w, h = _get_geometry(cell) - _set_geometry(geom, x, y, w, h) - y += h + gap - return - - if diagram_type == "sequence": - x = margin - y = margin - for cell in vertices: - geom, _, _, w, h = _get_geometry(cell) - _set_geometry(geom, x, y, w, h) - x += max(w + gap, 160) - return - - if diagram_type == "mindmap": - root = max(vertices, key=lambda c: (_get_geometry(c)[2] * _get_geometry(c)[3])) - others = [v for v in vertices if v is not root] - geom, _, _, w, h = _get_geometry(root) - center_x = canvas_width / 2.0 - 
w / 2.0 - center_y = canvas_height / 2.0 - h / 2.0 - _set_geometry(geom, center_x, center_y, w, h) - if not others: - return - import math - radius = max(200.0, min(canvas_width, canvas_height) / 3.0) - for idx, cell in enumerate(others): - angle = (2 * math.pi * idx) / max(1, len(others)) - geom, _, _, cw, ch = _get_geometry(cell) - x = center_x + w / 2.0 + math.cos(angle) * radius - cw / 2.0 - y = center_y + h / 2.0 + math.sin(angle) * radius - ch / 2.0 - _set_geometry(geom, x, y, cw, ch) - return - - if diagram_type == "er": - _grid_layout(vertices, margin, margin, canvas_width - margin * 2, gap_x=220, gap_y=160) - return - - _grid_layout(vertices, margin, margin, canvas_width - margin * 2, gap_x=220, gap_y=160) - - -def _add_edge_waypoints(root: ET.Element, cell_by_id: Dict[str, ET.Element]) -> None: - for cell in root.findall('.//mxCell'): - if cell.get('edge') != '1': - continue - source = cell.get('source') - target = cell.get('target') - if not source or not target: - continue - s_cell = cell_by_id.get(source) - t_cell = cell_by_id.get(target) - if s_cell is None or t_cell is None: - continue - geom = cell.find('mxGeometry') - if geom is None: - geom = ET.SubElement(cell, 'mxGeometry') - geom.set('relative', '1') - geom.set('as', 'geometry') - # Skip if already has waypoints - if geom.findall('mxPoint') or geom.find("Array[@as='points']") is not None: - continue - sx, sy = _center(s_cell) - tx, ty = _center(t_cell) - mid_x = (sx + tx) / 2.0 - # Add two points to encourage orthogonal routing (must be inside Array as="points") - points = ET.SubElement(geom, 'Array') - points.set('as', 'points') - p1 = ET.SubElement(points, 'mxPoint') - p1.set('x', f"{mid_x:.0f}") - p1.set('y', f"{sy:.0f}") - p2 = ET.SubElement(points, 'mxPoint') - p2.set('x', f"{mid_x:.0f}") - p2.set('y', f"{ty:.0f}") - - -def resolve_overlaps( - cells_xml: str, - diagram_type: str = "auto", - canvas_width: float = 800, - canvas_height: float = 600, - margin: float = 40, - gap: float = 60, - 
max_attempts: int = 200 -) -> str: - """ - Resolve vertex overlaps by incrementally shifting nodes on a simple grid. - - This is a conservative post-process: it preserves original order and only moves - nodes when overlaps are detected. - """ - if not cells_xml: - return "" - - try: - wrapped = f"{cells_xml}" - root = ET.fromstring(wrapped) - except ET.ParseError: - return cells_xml - - vertices = list(_iter_vertices(root)) - if not vertices: - return cells_xml - cell_by_id: Dict[str, ET.Element] = {cell.get('id'): cell for cell in root.findall('.//mxCell') if cell.get('id')} - - children_by_parent: Dict[str, List[ET.Element]] = {} - for cell in vertices: - parent = cell.get('parent') or "1" - children_by_parent.setdefault(parent, []).append(cell) - - # Layout children inside containers first - for parent_id, children in list(children_by_parent.items()): - if parent_id == "1": - continue - container = cell_by_id.get(parent_id) - if container is None: - continue - if container.get('vertex') == '1' or _style_has(container, "swimlane"): - _layout_children_in_container(container, children, padding=20, gap=gap) - - # Layout top-level vertices by diagram type - top_level = [c for c in vertices if (c.get('parent') or "1") == "1"] - _layout_top_level(diagram_type, top_level, margin, canvas_width, canvas_height, gap) - - # Final overlap pass (conservative) - placed: List[Tuple[float, float, float, float]] = [] - for cell in vertices: - geom, x, y, w, h = _get_geometry(cell) - attempts = 0 - while any(_overlaps((x, y, w, h), other, gap / 2) for other in placed): - attempts += 1 - x += gap / 2 - if x + w > canvas_width - margin: - x = margin - y += gap / 2 - if attempts >= max_attempts: - break - _set_geometry(geom, x, y, w, h) - placed.append((x, y, w, h)) - - _add_edge_waypoints(root, cell_by_id) - - result_cells = [ET.tostring(cell, encoding='unicode') for cell in root.findall('mxCell')] - return '\n'.join(result_cells) - - -def export_drawio_png( - cells_xml: str, - 
output_path: str, - drawio_bin: Optional[str] = None, - timeout: int = 60, -) -> Tuple[bool, str]: - """ - Render draw.io XML to PNG using draw.io CLI (server-side renderer). - - Requires draw.io/diagrams.net CLI installed and available in PATH. - - Returns: - (ok, message) - ok True if PNG created, else False with error message. - """ - if not cells_xml: - return False, "empty xml" - - output_file = Path(output_path).resolve() - output_file.parent.mkdir(parents=True, exist_ok=True) - - full_xml = cells_xml if " Tuple[str, List[str]]: - """ - 应用编辑操作到现有 XML - - Args: - current_xml: 当前的 mxCell XML - operations: 编辑操作列表 - [{"operation": "update|add|delete", "cell_id": "...", "new_xml": "..."}] - - Returns: - (new_xml, errors) 元组 - """ - errors = [] - - try: - # 包装后解析 - wrapped = f"{current_xml}" - root = ET.fromstring(wrapped) - - for op in operations: - operation = op.get('operation') - cell_id = op.get('cell_id') - new_xml = op.get('new_xml', '') - - if operation == 'delete': - # 删除单元格 - cell = root.find(f".//mxCell[@id='{cell_id}']") - if cell is not None: - parent = root.find(f".//mxCell[@id='{cell_id}']/..") - if parent is not None: - parent.remove(cell) - else: - errors.append(f"未找到要删除的单元格: {cell_id}") - - elif operation == 'update': - # 更新单元格 - cell = root.find(f".//mxCell[@id='{cell_id}']") - if cell is not None: - parent = root.find(f".//mxCell[@id='{cell_id}']/..") - idx = list(parent).index(cell) - parent.remove(cell) - new_cell = ET.fromstring(new_xml) - parent.insert(idx, new_cell) - else: - errors.append(f"未找到要更新的单元格: {cell_id}") - - elif operation == 'add': - # 添加新单元格 - new_cell = ET.fromstring(new_xml) - root.append(new_cell) - else: - errors.append(f"未知操作: {operation}") - - # 提取结果 - result_cells = [] - for cell in root.findall('mxCell'): - result_cells.append(ET.tostring(cell, encoding='unicode')) - - return '\n'.join(result_cells), errors - - except ET.ParseError as e: - errors.append(f"XML 解析错误: {str(e)}") - return current_xml, errors - - -def 
get_cell_ids(cells_xml: str) -> List[str]: - """获取所有单元格 ID""" - ids = [] - try: - wrapped = f"{cells_xml}" - root = ET.fromstring(wrapped) - for cell in root.findall('.//mxCell'): - cell_id = cell.get('id') - if cell_id: - ids.append(cell_id) - except ET.ParseError: - # 用正则提取 - pattern = r']*id=["\']([^"\']+)["\']' - ids = re.findall(pattern, cells_xml) - return ids - - -def generate_next_id(cells_xml: str) -> str: - """生成下一个可用的 ID""" - ids = get_cell_ids(cells_xml) - max_num = 1 - for id_str in ids: - try: - num = int(id_str) - max_num = max(max_num, num) - except ValueError: - continue - return str(max_num + 1) diff --git a/dataflow_agent/toolkits/filetool/__init__.py b/dataflow_agent/toolkits/filetool/__init__.py deleted file mode 100644 index ab70b0d..0000000 --- a/dataflow_agent/toolkits/filetool/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -文件工具模块 - -提供文件内容读取和目录内容查看功能。 -""" -from dataflow_agent.toolkits.filetool.filetools import ( - # LangChain Tool 封装 - read_text_file, - list_directory, - # 本地工具函数 - local_tool_read_file, - local_tool_list_directory, - # 底层函数 - read_file_content, - list_directory_content, -) - -__all__ = [ - "read_text_file", - "list_directory", - "local_tool_read_file", - "local_tool_list_directory", - "read_file_content", - "list_directory_content", -] diff --git a/dataflow_agent/toolkits/filetool/filetools.py b/dataflow_agent/toolkits/filetool/filetools.py deleted file mode 100644 index cc27b98..0000000 --- a/dataflow_agent/toolkits/filetool/filetools.py +++ /dev/null @@ -1,495 +0,0 @@ -""" -文件操作工具模块 - -提供文件内容读取和目录内容查看功能,支持跨平台(Windows/Linux)。 -所有操作以项目根目录为边界,不允许访问根目录之外的文件。 -""" -from __future__ import annotations - -import os -import platform -import subprocess -from pathlib import Path -from typing import Optional, List, Union, Dict, Any - -from langchain_core.tools import tool -from dataflow_agent.logger import get_logger -import dataflow_agent.utils as utils - -log = get_logger(__name__) - -# 获取项目根目录 -PROJECT_ROOT = 
utils.get_project_root() - - -def _is_path_within_project(path: Path) -> bool: - """ - 检查路径是否在项目根目录内 - - Args: - path: 要检查的路径 - - Returns: - bool: 如果路径在项目根目录内返回 True,否则返回 False - """ - try: - resolved_path = path.resolve() - resolved_root = PROJECT_ROOT.resolve() - # 检查路径是否以项目根目录开头 - return str(resolved_path).startswith(str(resolved_root)) - except Exception as e: - log.warning(f"路径检查失败: {e}") - return False - - -def _resolve_path(path_str: str) -> Path: - """ - 解析路径,支持相对路径和绝对路径 - - Args: - path_str: 路径字符串 - - Returns: - Path: 解析后的路径对象 - """ - path = Path(path_str) - if not path.is_absolute(): - # 相对路径基于项目根目录 - path = PROJECT_ROOT / path - return path.resolve() - - -def read_file_content( - file_path: str, - start_line: Optional[int] = None, - end_line: Optional[int] = None, - encoding: str = "utf-8" -) -> Dict[str, Any]: - """ - 读取文件内容 - - Args: - file_path: 文件路径(相对于项目根目录或绝对路径) - start_line: 起始行号(从1开始,可选) - end_line: 结束行号(包含,可选) - encoding: 文件编码,默认 utf-8 - - Returns: - Dict 包含: - - success: 是否成功 - - content: 文件内容(成功时) - - total_lines: 文件总行数(成功时) - - read_lines: 实际读取的行范围 [start, end](成功时) - - error: 错误信息(失败时) - """ - try: - path = _resolve_path(file_path) - - # 安全检查:确保路径在项目根目录内 - if not _is_path_within_project(path): - return { - "success": False, - "error": f"安全限制:不允许访问项目根目录之外的文件。项目根目录: {PROJECT_ROOT}" - } - - # 检查文件是否存在 - if not path.exists(): - return { - "success": False, - "error": f"文件不存在: {path}" - } - - # 检查是否为文件 - if not path.is_file(): - return { - "success": False, - "error": f"路径不是文件: {path}" - } - - # 读取文件内容 - try: - with open(path, 'r', encoding=encoding) as f: - lines = f.readlines() - except UnicodeDecodeError: - return { - "success": False, - "error": f"文件编码错误,无法使用 {encoding} 编码读取。请尝试其他编码或确认文件为文本文件。" - } - - total_lines = len(lines) - - # 处理行号范围 - actual_start = 1 - actual_end = total_lines - - if start_line is not None: - if start_line < 1: - start_line = 1 - actual_start = min(start_line, total_lines) - - if end_line is not None: - if end_line < 
1: - end_line = 1 - actual_end = min(end_line, total_lines) - - # 确保 start <= end - if actual_start > actual_end: - actual_start, actual_end = actual_end, actual_start - - # 提取指定行范围的内容(行号从1开始,索引从0开始) - selected_lines = lines[actual_start - 1:actual_end] - content = ''.join(selected_lines) - - log.info(f"[read_file_content] 成功读取文件: {path}, 行范围: {actual_start}-{actual_end}/{total_lines}") - - return { - "success": True, - "content": content, - "total_lines": total_lines, - "read_lines": [actual_start, actual_end], - "file_path": str(path) - } - - except Exception as e: - log.error(f"[read_file_content] 读取文件失败: {e}") - return { - "success": False, - "error": str(e) - } - - -def list_directory_content( - dir_path: str, - show_hidden: bool = False, - recursive: bool = False, - max_depth: int = 1 -) -> Dict[str, Any]: - """ - 查看目录内容,支持 Windows 和 Linux - - Args: - dir_path: 目录路径(相对于项目根目录或绝对路径) - show_hidden: 是否显示隐藏文件,默认 False - recursive: 是否递归显示子目录,默认 False - max_depth: 递归最大深度(仅当 recursive=True 时有效),默认 1 - - Returns: - Dict 包含: - - success: 是否成功 - - content: 目录内容(成功时) - - path: 目录绝对路径(成功时) - - error: 错误信息(失败时) - """ - try: - path = _resolve_path(dir_path) - - # 安全检查:确保路径在项目根目录内 - if not _is_path_within_project(path): - return { - "success": False, - "error": f"安全限制:不允许访问项目根目录之外的目录。项目根目录: {PROJECT_ROOT}" - } - - # 检查目录是否存在 - if not path.exists(): - return { - "success": False, - "error": f"目录不存在: {path}" - } - - # 检查是否为目录 - if not path.is_dir(): - return { - "success": False, - "error": f"路径不是目录: {path}" - } - - # 根据操作系统选择命令 - system = platform.system().lower() - - if system == "windows": - # Windows 使用 dir 命令 - cmd = ["cmd", "/c", "dir"] - if show_hidden: - cmd.append("/a") # 显示所有文件包括隐藏文件 - if recursive: - cmd.append("/s") # 递归显示 - cmd.append(str(path)) - else: - # Linux/macOS 使用 ls 命令 - cmd = ["ls", "-l"] - if show_hidden: - cmd.append("-a") # 显示隐藏文件 - if recursive: - cmd.append("-R") # 递归显示 - cmd.append(str(path)) - - # 执行命令 - try: - result = subprocess.run( - cmd, - 
capture_output=True, - text=True, - timeout=30, - cwd=str(path.parent) # 在父目录执行,避免路径问题 - ) - - output = result.stdout - if result.returncode != 0 and result.stderr: - output += f"\n[stderr]: {result.stderr}" - - except subprocess.TimeoutExpired: - return { - "success": False, - "error": "命令执行超时(30秒)" - } - except FileNotFoundError as e: - # 如果命令不存在,使用 Python 原生方式 - log.warning(f"系统命令不可用,使用 Python 原生方式: {e}") - output = _list_directory_python(path, show_hidden, recursive, max_depth) - - log.info(f"[list_directory_content] 成功列出目录: {path}") - - return { - "success": True, - "content": output, - "path": str(path), - "system": system - } - - except Exception as e: - log.error(f"[list_directory_content] 列出目录失败: {e}") - return { - "success": False, - "error": str(e) - } - - -def _list_directory_python( - path: Path, - show_hidden: bool = False, - recursive: bool = False, - max_depth: int = 1, - current_depth: int = 0, - prefix: str = "" -) -> str: - """ - 使用 Python 原生方式列出目录内容(备用方案) - - Args: - path: 目录路径 - show_hidden: 是否显示隐藏文件 - recursive: 是否递归 - max_depth: 最大递归深度 - current_depth: 当前深度 - prefix: 输出前缀(用于缩进) - - Returns: - str: 格式化的目录内容 - """ - lines = [] - - if current_depth == 0: - lines.append(f"目录: {path}") - lines.append("-" * 60) - - try: - entries = sorted(path.iterdir(), key=lambda x: (not x.is_dir(), x.name.lower())) - - for entry in entries: - # 跳过隐藏文件(如果不显示) - if not show_hidden and entry.name.startswith('.'): - continue - - # 获取文件信息 - try: - stat = entry.stat() - size = stat.st_size - is_dir = entry.is_dir() - - # 格式化大小 - if is_dir: - size_str = "" - elif size < 1024: - size_str = f"{size}B" - elif size < 1024 * 1024: - size_str = f"{size / 1024:.1f}KB" - else: - size_str = f"{size / (1024 * 1024):.1f}MB" - - # 格式化输出 - type_indicator = "📁" if is_dir else "📄" - lines.append(f"{prefix}{type_indicator} {entry.name:<40} {size_str:>10}") - - # 递归处理子目录 - if recursive and is_dir and current_depth < max_depth: - sub_content = _list_directory_python( - entry, - 
show_hidden, - recursive, - max_depth, - current_depth + 1, - prefix + " " - ) - lines.append(sub_content) - - except PermissionError: - lines.append(f"{prefix}⚠️ {entry.name:<40} [权限不足]") - except Exception as e: - lines.append(f"{prefix}⚠️ {entry.name:<40} [错误: {e}]") - - except PermissionError: - lines.append(f"{prefix}[权限不足,无法读取目录]") - except Exception as e: - lines.append(f"{prefix}[错误: {e}]") - - return "\n".join(lines) - - -# ==================== LangChain Tool 封装 ==================== - -@tool -def read_text_file( - file_path: str, - start_line: Optional[int] = None, - end_line: Optional[int] = None -) -> str: - """ - 读取文本文件内容。 - - 支持读取项目内的任意文本文件,可指定读取的行范围。 - 出于安全考虑,只能读取项目根目录内的文件。 - - Args: - file_path: 文件路径,可以是相对路径(相对于项目根目录)或绝对路径 - start_line: 起始行号(从1开始,可选)。不指定则从第1行开始 - end_line: 结束行号(包含,可选)。不指定则读取到文件末尾 - - Returns: - JSON 格式的结果,包含文件内容或错误信息 - - Examples: - >>> read_text_file("README.md") # 读取整个文件 - >>> read_text_file("src/main.py", start_line=10, end_line=20) # 读取第10-20行 - >>> read_text_file("config.yaml", end_line=50) # 读取前50行 - """ - import json - result = read_file_content(file_path, start_line, end_line) - return json.dumps(result, ensure_ascii=False, indent=2) - - -@tool -def list_directory( - dir_path: str, - show_hidden: bool = False, - recursive: bool = False -) -> str: - """ - 查看目录内容。 - - 列出指定目录下的文件和子目录,支持 Windows 和 Linux 系统。 - 出于安全考虑,只能查看项目根目录内的目录。 - - Args: - dir_path: 目录路径,可以是相对路径(相对于项目根目录)或绝对路径 - show_hidden: 是否显示隐藏文件(以.开头的文件),默认 False - recursive: 是否递归显示子目录内容,默认 False - - Returns: - JSON 格式的结果,包含目录内容或错误信息 - - Examples: - >>> list_directory(".") # 列出项目根目录 - >>> list_directory("src", show_hidden=True) # 列出 src 目录,包含隐藏文件 - >>> list_directory("dataflow_agent", recursive=True) # 递归列出目录 - """ - import json - result = list_directory_content(dir_path, show_hidden, recursive) - return json.dumps(result, ensure_ascii=False, indent=2) - - -# ==================== 直接调用的函数接口 ==================== - -def local_tool_read_file( - file_path: str, - start_line: 
Optional[int] = None, - end_line: Optional[int] = None, - encoding: str = "utf-8" -) -> Dict[str, Any]: - """ - 本地工具:读取文件内容 - - 直接返回字典结果,适合在代码中直接调用。 - - Args: - file_path: 文件路径 - start_line: 起始行号(从1开始,可选) - end_line: 结束行号(包含,可选) - encoding: 文件编码 - - Returns: - Dict: 包含 success, content/error 等字段 - """ - return read_file_content(file_path, start_line, end_line, encoding) - - -def local_tool_list_directory( - dir_path: str, - show_hidden: bool = False, - recursive: bool = False, - max_depth: int = 1 -) -> Dict[str, Any]: - """ - 本地工具:列出目录内容 - - 直接返回字典结果,适合在代码中直接调用。 - - Args: - dir_path: 目录路径 - show_hidden: 是否显示隐藏文件 - recursive: 是否递归 - max_depth: 递归最大深度 - - Returns: - Dict: 包含 success, content/error 等字段 - """ - return list_directory_content(dir_path, show_hidden, recursive, max_depth) - - -# ==================== 测试代码 ==================== - -if __name__ == "__main__": - import json - - print("=" * 60) - print("文件工具测试") - print("=" * 60) - - # 测试1:读取文件 - print("\n--- 测试1:读取 README.md ---") - result = read_file_content("README.md", end_line=10) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 测试2:读取指定行范围 - print("\n--- 测试2:读取文件指定行 ---") - result = read_file_content("dataflow_agent/utils.py", start_line=1, end_line=20) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 测试3:列出目录 - print("\n--- 测试3:列出项目根目录 ---") - result = list_directory_content(".") - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 测试4:递归列出目录 - print("\n--- 测试4:递归列出 toolkits 目录 ---") - result = list_directory_content("dataflow_agent/toolkits", recursive=True) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 测试5:安全检查 - 尝试访问项目外的路径 - print("\n--- 测试5:安全检查(访问项目外路径)---") - result = read_file_content("/etc/passwd") - print(json.dumps(result, ensure_ascii=False, indent=2)) - - print("\n" + "=" * 60) - print("测试完成") - print("=" * 60) diff --git a/dataflow_agent/toolkits/image2drawio/__init__.py b/dataflow_agent/toolkits/image2drawio/__init__.py deleted file 
mode 100644 index b70cabc..0000000 --- a/dataflow_agent/toolkits/image2drawio/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Image2DrawIO toolkit utilities.""" - -from .utils import ( - classify_shape, - extract_text_color, - mask_to_bbox, - normalize_mask, - sample_fill_stroke, - save_masked_rgba, - bbox_iou_px, -) - -__all__ = [ - "classify_shape", - "extract_text_color", - "mask_to_bbox", - "normalize_mask", - "sample_fill_stroke", - "save_masked_rgba", - "bbox_iou_px", -] diff --git a/dataflow_agent/toolkits/image2drawio/utils.py b/dataflow_agent/toolkits/image2drawio/utils.py deleted file mode 100644 index 28de201..0000000 --- a/dataflow_agent/toolkits/image2drawio/utils.py +++ /dev/null @@ -1,196 +0,0 @@ -from __future__ import annotations - -from typing import List, Tuple, Optional -import os -import math - -import cv2 -import numpy as np - - -def normalize_mask(mask: np.ndarray, target_shape: Tuple[int, int]) -> np.ndarray: - """Ensure mask is boolean and matches target (H, W).""" - if mask is None: - raise ValueError("mask is None") - if mask.dtype != np.bool_: - mask = mask.astype(bool) - - h, w = target_shape - if mask.shape[0] != h or mask.shape[1] != w: - mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST) - mask = mask.astype(bool) - return mask - - -def mask_to_bbox(mask: np.ndarray) -> Optional[List[int]]: - ys, xs = np.where(mask) - if xs.size == 0 or ys.size == 0: - return None - x1 = int(xs.min()) - x2 = int(xs.max()) - y1 = int(ys.min()) - y2 = int(ys.max()) - if x2 <= x1 or y2 <= y1: - return None - return [x1, y1, x2, y2] - - -def classify_shape(mask: np.ndarray) -> Tuple[str, float]: - """ - Heuristic shape classification. 
- - Returns (shape_type, confidence) where shape_type in: - rect | rounded_rect | ellipse | diamond | unknown - """ - cnts, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if not cnts: - return "unknown", 0.0 - - cnt = max(cnts, key=cv2.contourArea) - area = cv2.contourArea(cnt) - if area < 50: - return "unknown", 0.0 - - peri = cv2.arcLength(cnt, True) - if peri <= 1e-6: - return "unknown", 0.0 - - approx = cv2.approxPolyDP(cnt, 0.02 * peri, True) - vertices = len(approx) - - x, y, w, h = cv2.boundingRect(cnt) - bbox_area = float(w * h) if w > 0 and h > 0 else 1.0 - area_ratio = float(area) / bbox_area - aspect = float(w) / float(h) if h > 0 else 1.0 - - circularity = 4.0 * math.pi * float(area) / (float(peri) * float(peri) + 1e-6) - - # Ellipse / circle - if circularity > 0.75: - return "ellipse", min(1.0, circularity) - - # Quadrilaterals - if vertices == 4: - if area_ratio < 0.7: - return "diamond", 0.8 - return "rect", 0.8 - - # Rounded rectangle: many vertices, high area ratio - if vertices >= 5 and area_ratio > 0.75 and 0.4 < aspect < 2.5: - return "rounded_rect", 0.6 - - return "unknown", 0.0 - - -def _to_hex(color_bgr: Tuple[int, int, int]) -> str: - b, g, r = [int(max(0, min(255, c))) for c in color_bgr] - return f"#{r:02x}{g:02x}{b:02x}" - - -def sample_fill_stroke(image_bgr: np.ndarray, mask: np.ndarray) -> Tuple[str, str]: - """ - Sample fill & stroke colors from original image using the mask. - Returns (fill_hex, stroke_hex). 
- """ - h, w = image_bgr.shape[:2] - mask = normalize_mask(mask, (h, w)) - - # Edge (stroke): dilate - erode - k = max(1, int(min(h, w) * 0.002)) - k = min(k, 5) - kernel = np.ones((k, k), np.uint8) - dil = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1) - ero = cv2.erode(mask.astype(np.uint8), kernel, iterations=1) - edge = (dil > 0) & (ero == 0) - - stroke_pixels = image_bgr[edge] - if stroke_pixels.size == 0: - stroke_pixels = image_bgr[mask] - - # Select darkest quartile by luminance - if stroke_pixels.size > 0: - rgb = stroke_pixels[:, ::-1].astype(np.float32) - lum = 0.2126 * rgb[:, 0] + 0.7152 * rgb[:, 1] + 0.0722 * rgb[:, 2] - if lum.size > 10: - thresh = np.percentile(lum, 25) - sel = stroke_pixels[lum <= thresh] - else: - sel = stroke_pixels - stroke = tuple(np.mean(sel, axis=0).tolist()) - else: - stroke = (0, 0, 0) - - # Fill: erode mask to remove border - erode_k = max(1, int(min(h, w) * 0.004)) - erode_k = min(erode_k, 7) - kernel2 = np.ones((erode_k, erode_k), np.uint8) - inner = cv2.erode(mask.astype(np.uint8), kernel2, iterations=1) > 0 - fill_pixels = image_bgr[inner] - if fill_pixels.size == 0: - fill_pixels = image_bgr[mask] - - if fill_pixels.size > 0: - fill = tuple(np.median(fill_pixels, axis=0).tolist()) - else: - fill = (255, 255, 255) - - return _to_hex(fill), _to_hex(stroke) - - -def extract_text_color(image_bgr: np.ndarray, bbox_px: List[int]) -> str: - x1, y1, x2, y2 = bbox_px - x1 = max(0, min(image_bgr.shape[1] - 1, int(x1))) - x2 = max(0, min(image_bgr.shape[1], int(x2))) - y1 = max(0, min(image_bgr.shape[0] - 1, int(y1))) - y2 = max(0, min(image_bgr.shape[0], int(y2))) - if x2 <= x1 or y2 <= y1: - return "#000000" - region = image_bgr[y1:y2, x1:x2] - if region.size == 0: - return "#000000" - rgb = region[:, :, ::-1].reshape(-1, 3).astype(np.float32) - lum = 0.2126 * rgb[:, 0] + 0.7152 * rgb[:, 1] + 0.0722 * rgb[:, 2] - if lum.size == 0: - return "#000000" - thresh = np.percentile(lum, 25) - sel = rgb[lum <= thresh] - if 
sel.size == 0: - sel = rgb - color = tuple(np.mean(sel, axis=0).tolist()) - r, g, b = [int(max(0, min(255, c))) for c in color] - return f"#{r:02x}{g:02x}{b:02x}" - - -def save_masked_rgba(image_bgr: np.ndarray, mask: np.ndarray, out_path: str) -> str: - """Save masked region as RGBA PNG with alpha channel.""" - h, w = image_bgr.shape[:2] - mask = normalize_mask(mask, (h, w)) - rgba = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2BGRA) - alpha = (mask.astype(np.uint8) * 255) - rgba[:, :, 3] = alpha - - bbox = mask_to_bbox(mask) - if bbox: - x1, y1, x2, y2 = bbox - x2 = min(w, x2 + 1) - y2 = min(h, y2 + 1) - crop = rgba[y1:y2, x1:x2] - else: - crop = rgba - - os.makedirs(os.path.dirname(out_path), exist_ok=True) - cv2.imwrite(out_path, crop) - return out_path - - -def bbox_iou_px(a: List[int], b: List[int]) -> float: - xA = max(a[0], b[0]) - yA = max(a[1], b[1]) - xB = min(a[2], b[2]) - yB = min(a[3], b[3]) - inter = max(0, xB - xA) * max(0, yB - yA) - area_a = max(0, a[2] - a[0]) * max(0, a[3] - a[1]) - area_b = max(0, b[2] - b[0]) * max(0, b[3] - b[1]) - if area_a == 0 or area_b == 0: - return 0.0 - return inter / float(area_a + area_b - inter + 1e-6) diff --git a/dataflow_agent/toolkits/model_servers/__init__.py b/dataflow_agent/toolkits/model_servers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/model_servers/generic_lb.py b/dataflow_agent/toolkits/model_servers/generic_lb.py deleted file mode 100644 index b1519e9..0000000 --- a/dataflow_agent/toolkits/model_servers/generic_lb.py +++ /dev/null @@ -1,77 +0,0 @@ -import uvicorn -import httpx -import argparse -import itertools -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse, Response - -app = FastAPI(title="Generic Model Load Balancer") - -# Default backends (can be overridden by args) -BACKEND_URLS = [] -iterator = None - -@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"]) -async def 
proxy(request: Request, path_name: str): - global iterator - if not iterator: - return Response("No backends configured", status_code=503) - - # Round-robin selection - target_base = next(iterator) - - # Construct target URL - url = f"{target_base}/{path_name}" - if request.url.query: - url += f"?{request.url.query}" - - # Forward headers (excluding Host to avoid conflicts) - headers = dict(request.headers) - headers.pop("host", None) - headers.pop("content-length", None) # Let httpx handle content-length - - async with httpx.AsyncClient(timeout=None) as client: - try: - # Read body - body = await request.body() - - # Build request - req = client.build_request( - request.method, - url, - headers=headers, - content=body, - timeout=None - ) - - # Send request and stream response - r = await client.send(req, stream=True) - - return StreamingResponse( - r.aiter_raw(), - status_code=r.status_code, - headers=r.headers, - background=None - ) - except Exception as e: - # Simple error handling - import traceback - traceback.print_exc() - return Response(content=f"Proxy Error: {str(e)}", status_code=502) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8000, help="Port for the load balancer") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the load balancer") - parser.add_argument("--name", type=str, default="Load Balancer", help="Name of the service") - parser.add_argument("--backends", nargs="+", required=True, help="List of backend URLs (e.g., http://localhost:8011)") - args = parser.parse_args() - - BACKEND_URLS = args.backends - iterator = itertools.cycle(BACKEND_URLS) - - app.title = args.name - print(f"Starting {args.name} on {args.host}:{args.port}") - print(f"Balancing between backends: {BACKEND_URLS}") - - uvicorn.run(app, host=args.host, port=args.port) diff --git a/dataflow_agent/toolkits/model_servers/mineru_server.py 
b/dataflow_agent/toolkits/model_servers/mineru_server.py deleted file mode 100644 index ffe17b2..0000000 --- a/dataflow_agent/toolkits/model_servers/mineru_server.py +++ /dev/null @@ -1,75 +0,0 @@ -import uvicorn -import httpx -import argparse -import itertools -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse, Response - -app = FastAPI(title="MinerU Load Balancer") - -# Default backends (can be overridden by args) -BACKEND_URLS = [] -iterator = None - -@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"]) -async def proxy(request: Request, path_name: str): - global iterator - if not iterator: - return Response("No backends configured", status_code=503) - - # Round-robin selection - target_base = next(iterator) - - # Construct target URL - url = f"{target_base}/{path_name}" - if request.url.query: - url += f"?{request.url.query}" - - # Forward headers (excluding Host to avoid conflicts) - headers = dict(request.headers) - headers.pop("host", None) - headers.pop("content-length", None) # Let httpx handle content-length - - async with httpx.AsyncClient(timeout=None) as client: - try: - # Read body - body = await request.body() - - # Build request - req = client.build_request( - request.method, - url, - headers=headers, - content=body, - timeout=None - ) - - # Send request and stream response - r = await client.send(req, stream=True) - - return StreamingResponse( - r.aiter_raw(), - status_code=r.status_code, - headers=r.headers, - background=None - ) - except Exception as e: - # Simple error handling - import traceback - traceback.print_exc() - return Response(content=f"Proxy Error: {str(e)}", status_code=502) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8010, help="Port for the load balancer") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the load balancer") - 
parser.add_argument("--backends", nargs="+", required=True, help="List of backend URLs (e.g., http://localhost:8011)") - args = parser.parse_args() - - BACKEND_URLS = args.backends - iterator = itertools.cycle(BACKEND_URLS) - - print(f"Starting MinerU Load Balancer on {args.host}:{args.port}") - print(f"Balancing between backends: {BACKEND_URLS}") - - uvicorn.run(app, host=args.host, port=args.port) diff --git a/dataflow_agent/toolkits/model_servers/ocr_server.py b/dataflow_agent/toolkits/model_servers/ocr_server.py deleted file mode 100644 index c2a939a..0000000 --- a/dataflow_agent/toolkits/model_servers/ocr_server.py +++ /dev/null @@ -1,77 +0,0 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import Optional, List, Any, Tuple, Union -import os -import sys - -# Add project root to path -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.abspath(os.path.join(current_dir, "../../../")) -if project_root not in sys.path: - sys.path.insert(0, project_root) - -from dataflow_agent.toolkits.multimodaltool.ppt_tool import paddle_ocr_page_with_layout - -app = FastAPI(title="OCR Model Server") - -class OCRRequest(BaseModel): - image_path: str - -class OCRLine(BaseModel): - bbox: List[float] # [x1, y1, x2, y2] - text: str - conf: float - -class OCRResponse(BaseModel): - image_size: Tuple[int, int] - lines: List[OCRLine] - body_h_px: Optional[float] = None - bg_color: Optional[Tuple[int, int, int]] = None - -@app.post("/predict", response_model=OCRResponse) -async def predict(req: OCRRequest): - """ - Run PaddleOCR on the given image path and analyze layout. 
- """ - if not os.path.exists(req.image_path): - raise HTTPException(status_code=404, detail=f"Image path not found: {req.image_path}") - - try: - # 调用本地 ppt_tool 函数 - # 注意:PADDLE_OCR 是全局初始化的,所以服务启动时加载一次 - result = paddle_ocr_page_with_layout(req.image_path) - - # result structure: - # { - # "image_size": (w, h), - # "lines": [(bbox, text, conf), ...], - # "body_h_px": float/None, - # "bg_color": (r,g,b)/None, - # } - - # Transform lines format to match Pydantic model - # from tuple to dict/object - transformed_lines = [] - for line in result.get("lines", []): - bbox, text, conf = line - transformed_lines.append(OCRLine( - bbox=bbox, - text=text, - conf=conf - )) - - return OCRResponse( - image_size=result.get("image_size", (0, 0)), - lines=transformed_lines, - body_h_px=result.get("body_h_px"), - bg_color=result.get("bg_color") - ) - - except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/health") -def health(): - return {"status": "ok"} diff --git a/dataflow_agent/toolkits/model_servers/sam3_server.py b/dataflow_agent/toolkits/model_servers/sam3_server.py deleted file mode 100644 index 81aeeba..0000000 --- a/dataflow_agent/toolkits/model_servers/sam3_server.py +++ /dev/null @@ -1,312 +0,0 @@ -from __future__ import annotations - -import argparse -import asyncio -import base64 -import io -import os -from collections import OrderedDict -from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple - -import cv2 -import numpy as np -import torch -import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field -from PIL import Image - - -def _ensure_sam3_importable() -> None: - import sys - - candidates = [] - sam3_home = os.environ.get("SAM3_HOME", "").strip() - if sam3_home: - candidates.append(Path(sam3_home)) - - project_root = Path(__file__).resolve().parents[3] - candidates.append(project_root / "models" / "sam3-official" / 
"sam3") - candidates.append(Path("/data/users/pzw/models/sam3-official/sam3")) - - for path in candidates: - if path.exists() and path.is_dir(): - p = str(path.resolve()) - if p not in sys.path: - sys.path.insert(0, p) - - -_ensure_sam3_importable() -from sam3.model_builder import build_sam3_image_model # type: ignore -from sam3.model.sam3_image_processor import Sam3Processor # type: ignore - - -class PredictRequest(BaseModel): - image_path: str = Field(..., description="Path to the image that the server can read") - prompts: List[str] = Field(..., min_items=1, description="Text prompts for SAM3") - return_masks: bool = Field(False, description="Whether to return mask data") - mask_format: Literal["rle", "png"] = Field( - "rle", description="Mask format: run-length encoding or base64 png" - ) - score_threshold: Optional[float] = Field(None, description="Override score threshold") - epsilon_factor: Optional[float] = Field(None, description="Override polygon epsilon factor") - min_area: Optional[int] = Field(None, description="Override minimum polygon area") - - -class PredictResponse(BaseModel): - image_size: Dict[str, int] - results: List[Dict] - - -def _encode_mask_rle(mask: np.ndarray) -> str: - flat = mask.reshape(-1).astype(np.uint8) - runs: List[int] = [] - last_val = flat[0] - length = 1 - for val in flat[1:]: - if val == last_val: - length += 1 - else: - runs.append(length) - length = 1 - last_val = val - runs.append(length) - return ",".join(str(x) for x in runs) - - -def _encode_mask_png(mask: np.ndarray) -> str: - buffer = io.BytesIO() - img = Image.fromarray(mask.astype(np.uint8)) - img.save(buffer, format="PNG") - buffer.seek(0) - return base64.b64encode(buffer.read()).decode("ascii") - - -def _extract_polygon(binary_mask: np.ndarray, epsilon_factor: float) -> Tuple[List[List[int]], float]: - contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if not contours: - return [], 0.0 - - max_cnt = None - max_area = 0.0 - 
for cnt in contours: - area = float(cv2.contourArea(cnt)) - if area > max_area: - max_area = area - max_cnt = cnt - - if max_cnt is None or max_area <= 0: - return [], 0.0 - - epsilon = epsilon_factor * cv2.arcLength(max_cnt, True) - approx = cv2.approxPolyDP(max_cnt, epsilon, True) - if approx is None or len(approx) < 3: - return [], 0.0 - return approx.reshape(-1, 2).tolist(), max_area - - -def _calculate_area(bbox: List[int]) -> int: - x1, y1, x2, y2 = bbox - return max(0, x2 - x1) * max(0, y2 - y1) - - -class Sam3Runtime: - def __init__( - self, - checkpoint_path: str, - bpe_path: Optional[str] = None, - score_threshold: float = 0.5, - epsilon_factor: float = 0.02, - min_area: int = 100, - device: str = "cuda", - cache_size: int = 2, - ) -> None: - self.score_threshold = score_threshold - self.epsilon_factor = epsilon_factor - self.min_area = min_area - - self.model = build_sam3_image_model( - bpe_path=bpe_path, - checkpoint_path=checkpoint_path, - load_from_HF=False, - device=device, - ) - self.processor = Sam3Processor(self.model, device=device) - - self.cache_size = cache_size - self.state_cache: OrderedDict[str, Dict] = OrderedDict() - self.cache_lock = asyncio.Lock() - self.inference_lock = asyncio.Lock() - - async def _get_image_state(self, image_path: str) -> Dict: - async with self.cache_lock: - if image_path in self.state_cache: - self.state_cache.move_to_end(image_path) - return self.state_cache[image_path] - - if not os.path.exists(image_path): - raise FileNotFoundError(f"Image not found: {image_path}") - pil_image = Image.open(image_path).convert("RGB") - canvas_size = pil_image.size - - image_state = self.processor.set_image(pil_image) - cache_item = { - "image_state": image_state, - "canvas_size": canvas_size, - } - - async with self.cache_lock: - self.state_cache[image_path] = cache_item - if len(self.state_cache) > self.cache_size: - self.state_cache.popitem(last=False) - return cache_item - - def _build_detection( - self, - prompt: str, - 
score: float, - bbox: List[int], - polygon: List[List[int]], - mask_payload: Optional[str], - mask_format: Optional[str], - mask_shape: Optional[List[int]], - ) -> Dict: - item: Dict = { - "prompt": prompt, - "score": score, - "bbox": bbox, - "polygon": polygon, - "area": _calculate_area(bbox), - } - if mask_payload is not None and mask_format is not None and mask_shape is not None: - item["mask"] = { - "data": mask_payload, - "format": mask_format, - "shape": mask_shape, - } - return item - - async def predict(self, payload: PredictRequest) -> PredictResponse: - async with self.inference_lock: - cache_item = await self._get_image_state(payload.image_path) - state = cache_item["image_state"] - canvas_w, canvas_h = cache_item["canvas_size"] - - score_threshold = payload.score_threshold or self.score_threshold - epsilon_factor = payload.epsilon_factor or self.epsilon_factor - min_area = payload.min_area or self.min_area - - all_results: List[Dict] = [] - - for prompt in payload.prompts: - self.processor.reset_all_prompts(state) - result_state = self.processor.set_text_prompt(prompt=prompt, state=state) - masks = result_state.get("masks", []) - boxes = result_state.get("boxes", []) - scores = result_state.get("scores", []) - - if masks is None or len(masks) == 0: - continue - - num_masks = masks.shape[0] if isinstance(masks, torch.Tensor) else len(masks) - for i in range(num_masks): - score_val = scores[i] - score_val = score_val.item() if hasattr(score_val, "item") else float(score_val) - if score_val < score_threshold: - continue - - box = boxes[i] - bbox = box.detach().cpu().numpy().tolist() if isinstance(box, torch.Tensor) else box - bbox = [int(v) for v in bbox] - - mask = masks[i] - binary_mask = mask.detach().cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask) - if binary_mask.ndim > 2: - binary_mask = binary_mask.squeeze() - binary_mask = (binary_mask > 0.5).astype(np.uint8) * 255 - - polygon, polygon_area = _extract_polygon(binary_mask, 
epsilon_factor) - if len(polygon) == 0 or polygon_area < min_area: - continue - - mask_payload = None - mask_shape = None - if payload.return_masks: - mask_shape = [binary_mask.shape[0], binary_mask.shape[1]] - if payload.mask_format == "png": - mask_payload = _encode_mask_png(binary_mask) - else: - mask_payload = _encode_mask_rle(binary_mask) - - all_results.append( - self._build_detection( - prompt=prompt, - score=score_val, - bbox=bbox, - polygon=polygon, - mask_payload=mask_payload, - mask_format=payload.mask_format if payload.return_masks else None, - mask_shape=mask_shape, - ) - ) - - return PredictResponse( - image_size={"width": canvas_w, "height": canvas_h}, - results=all_results, - ) - - -def create_app(runtime: Sam3Runtime) -> FastAPI: - app = FastAPI(title="SAM3 Model Server", version="1.0.0") - - @app.get("/health") - async def health() -> Dict[str, str]: - return {"status": "ok"} - - @app.post("/predict", response_model=PredictResponse) - async def predict(request: PredictRequest) -> PredictResponse: - try: - return await runtime.predict(request) - except FileNotFoundError as exc: - raise HTTPException(status_code=404, detail=str(exc)) from exc - except Exception as exc: - raise HTTPException(status_code=500, detail=str(exc)) from exc - - return app - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Start a persistent SAM3 HTTP service") - parser.add_argument("--host", default="0.0.0.0", help="Host to bind") - parser.add_argument("--port", type=int, default=8001, help="Port to bind") - parser.add_argument( - "--checkpoint", - default=os.getenv("SAM3_CHECKPOINT_PATH", "/data/users/pzw/models/sam3/sam3.pt"), - help="Path to SAM3 checkpoint", - ) - parser.add_argument( - "--bpe", - default=os.getenv("SAM3_BPE_PATH", "/data/users/pzw/models/sam3/bpe_simple_vocab_16e6.txt.gz"), - help="Path to SAM3 BPE file", - ) - parser.add_argument("--score-threshold", type=float, default=0.5) - 
parser.add_argument("--epsilon-factor", type=float, default=0.02) - parser.add_argument("--min-area", type=int, default=100) - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device id") - parser.add_argument("--cache-size", type=int, default=2, help="LRU cache size for encoded images") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - runtime = Sam3Runtime( - checkpoint_path=args.checkpoint, - bpe_path=args.bpe, - score_threshold=args.score_threshold, - epsilon_factor=args.epsilon_factor, - min_area=args.min_area, - device=args.device, - cache_size=args.cache_size, - ) - uvicorn.run(create_app(runtime), host=args.host, port=args.port, workers=1) diff --git a/dataflow_agent/toolkits/model_servers/sam_server.py b/dataflow_agent/toolkits/model_servers/sam_server.py deleted file mode 100644 index 1c83fe1..0000000 --- a/dataflow_agent/toolkits/model_servers/sam_server.py +++ /dev/null @@ -1,121 +0,0 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import Optional, List, Any, Dict -import numpy as np -import base64 -import zlib -import os -import sys - -# Add project root to path to ensure imports work -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.abspath(os.path.join(current_dir, "../../../")) -if project_root not in sys.path: - sys.path.insert(0, project_root) - -from dataflow_agent.toolkits.multimodaltool.sam_tool import run_sam_auto, free_sam_model - -try: - import torch -except ImportError: - torch = None - -app = FastAPI(title="SAM Model Server") - -# Check CUDA device on startup -@app.on_event("startup") -async def startup_event(): - print("SAM Server Startup Check:") - print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not Set')}") - if torch and torch.cuda.is_available(): - print(f"Torch CUDA available: {torch.cuda.is_available()}") - print(f"Current Device Count: 
{torch.cuda.device_count()}") - print(f"Current Device Name: {torch.cuda.get_device_name(0)}") - else: - print("CUDA NOT AVAILABLE") - -class SAMRequest(BaseModel): - image_path: str - checkpoint: str = "sam_b.pt" - device: str = "cuda" - -class SAMItemResponse(BaseModel): - mask_b64: str - mask_shape: List[int] - bbox: List[float] - score: Optional[float] = None - area: int - -class SAMResponse(BaseModel): - items: List[SAMItemResponse] - -@app.post("/predict", response_model=SAMResponse) -def predict(req: SAMRequest): - """ - Run SAM auto segmentation on the given image path. - """ - if not os.path.exists(req.image_path): - raise HTTPException(status_code=404, detail=f"Image path not found: {req.image_path}") - - try: - # Use the device from request, CUDA_VISIBLE_DEVICES will handle GPU mapping - target_device = req.device - - # 调用本地的 sam_tool 函数 - # 注意:这里会利用 sam_tool 内部的 caching 机制 - # 如果启动了多个 sam_server 进程,每个进程会维护自己的 cache - items = run_sam_auto( - image_path=req.image_path, - checkpoint=req.checkpoint, - device=target_device - ) - - # 序列化结果 - serialized_items = [] - for it in items: - mask = it.get("mask") - if mask is None: - continue - - # Convert mask to base64 - # mask is numpy array (bool or uint8) - if not isinstance(mask, np.ndarray): - mask = np.array(mask) - - # Use bool type for serialization consistency - mask_bool = mask.astype(bool) - mask_bytes = mask_bool.tobytes() - # Compress using zlib to reduce payload size - compressed_bytes = zlib.compress(mask_bytes) - mask_b64 = base64.b64encode(compressed_bytes).decode('utf-8') - - serialized_items.append(SAMItemResponse( - mask_b64=mask_b64, - mask_shape=list(mask.shape), - bbox=it.get("bbox", []), - score=it.get("score"), - area=it.get("area", 0) - )) - - return SAMResponse(items=serialized_items) - - except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) - finally: - # Aggressive cleanup to prevent OOM - if torch and 
torch.cuda.is_available(): - torch.cuda.empty_cache() - -@app.post("/free_model") -def free_model(checkpoint: str = "sam_b.pt"): - try: - free_sam_model(checkpoint) - return {"status": "ok", "message": f"Model {checkpoint} freed"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app.get("/health") -def health(): - return {"status": "ok"} diff --git a/dataflow_agent/toolkits/multimodaltool/__init__.py b/dataflow_agent/toolkits/multimodaltool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/optool/__init__.py b/dataflow_agent/toolkits/optool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/optool/op_tools.py b/dataflow_agent/toolkits/optool/op_tools.py deleted file mode 100644 index 3484f4c..0000000 --- a/dataflow_agent/toolkits/optool/op_tools.py +++ /dev/null @@ -1,958 +0,0 @@ -from __future__ import annotations -import asyncio -import inspect -import sys -import os -from pydantic import BaseModel -import httpx -import json -import uuid -from typing import List, Dict, Sequence, Any, Union, Optional, Iterable, Mapping, Set, Callable -from pathlib import Path - -from functools import lru_cache -import yaml -# from clickhouse_connect import get_client -import subprocess -from collections import defaultdict, deque -from dataflow.utils.storage import FileStorage -# from dataflow_agent.logger import get_logger -# logger = get_logger() -from dataflow_agent.storage.storage_service import SampleFileStorage -from dataflow_agent.state import DFState,DFRequest - -import inspect -import json -import os -from pathlib import Path -from typing import Any, Dict, List, Tuple - -from dataflow.utils.registry import OPERATOR_REGISTRY -from langchain_core.tools import tool -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) -RESOURCE_DIR = Path(__file__).resolve().parent.parent / "resources" -OPS_JSON_PATH = RESOURCE_DIR / "ops.json" - 
-def local_tool_for_get_purpose(req: DFRequest) -> str: - return req.target or "" - -# ===================================================================更新算子库部分代码: -def _safe_json_val(val: Any) -> Any: - """ - 把任意 Python 对象转换成 JSON 可序列化的值。 - 规则: - 1. 基本类型(None / bool / int / float / str)直接返回; - 2. enum/类对象 → 返回 'module.qualname'; - 3. 其它复杂对象 → 返回 str(val); - """ - # 空值直接交给 _param_to_dict 去处理 - if val is inspect.Parameter.empty: - return None - - # 基本可 JSON 类型 - if isinstance(val, (str, int, float, bool)) or val is None: - return val - - # 类、函数、枚举等 → module.qualname - if isinstance(val, type): - return f"{val.__module__}.{val.__qualname__}" - - # Python3.10+ 的 A | B 产生的 UnionType - if getattr(val, "__origin__", None) is None and val.__class__.__name__ == "UnionType": - return str(val) # e.g. "A | B | C" - - # 尝试直接 dump - try: - json.dumps(val) - return val - except TypeError: - return str(val) - -# 工具函数:安全调用带 @staticmethod 的 get_desc(lang) -def _call_get_desc_static(cls, lang: str = "zh") -> str | None: - """ - 仅当类的 get_desc 被显式声明为 @staticmethod 时才调用。 - 兼容两种签名: (lang) 或 (self, lang)。 - 返回 None 表示跳过此算子。 - """ - func_obj = cls.__dict__.get("get_desc") - if not isinstance(func_obj, staticmethod): - return None - - fn = func_obj.__func__ - params = list(inspect.signature(fn).parameters) - try: - if params == ["lang"]: - return fn(lang) - if params == ["self", "lang"]: - return fn(None, lang) - except Exception as e: - log.warning(f"调用 {cls.__name__}.get_desc 失败: {e}") - return None - - -# --------------------------------------------------------------------------- -def _param_to_dict(p: inspect.Parameter) -> Dict[str, Any]: - """把 inspect.Parameter 转成 JSON 可序列化的字典(参考 MCP func 定义)""" - return { - "name": p.name, - # "default": None if p.default is inspect.Parameter.empty else p.default, - "default": _safe_json_val(p.default), - "kind": p.kind.name, # POSITIONAL_OR_KEYWORD / VAR_POSITIONAL / ... 
- } - - -def _get_method_params( - method: Any, skip_first_self: bool = False -) -> List[Dict[str, Any]]: - """ - 提取方法形参,转换为列表。 - skip_first_self=True 时会丢掉第一个 self 参数。 - """ - try: - sig = inspect.signature(method) - params = list(sig.parameters.values()) - if skip_first_self and params and params[0].name == "self": - params = params[1:] - return [_param_to_dict(p) for p in params] - except Exception as e: - log.warning(f"获取方法参数出错: {e}") - return [] - - -def _gather_single_operator( - op_name: str, cls: type, node_index: int -) -> Tuple[str, Dict[str, Any]]: - """ - 收集单个算子的全部信息,返回 (category, info_dict) - """ - # 1) 分类:dataflow.operators..xxx - category = "unknown" - if hasattr(cls, "__module__"): - parts = cls.__module__.split(".") - if len(parts) >= 3 and parts[0] == "dataflow" and parts[1] == "operators": - category = parts[2] - - # 2) 描述 - description = _call_get_desc_static(cls, lang="zh") or "" - - # 3) command 形参 - init_params = _get_method_params(cls.__init__, skip_first_self=True) - run_params = _get_method_params(getattr(cls, "run", None), skip_first_self=True) - - info = { - "node": node_index, - "name": op_name, - "description": description, - "parameter": { - "init": init_params, - "run": run_params, - }, - # 下面三项暂时留空,后续有需要再填 - "required": "", - "depends_on": [], - "mode": "", - } - return category, info - - -def _dump_all_ops_to_file() -> Dict[str, List[Dict[str, Any]]]: - """ - 遍历 OPERATOR_REGISTRY,构建完整字典并写入 ops.json。 - 额外添加 "Default" → 所有算子全集。 - """ - log.info("开始扫描 OPERATOR_REGISTRY,生成 ops.json ...") - - if hasattr(OPERATOR_REGISTRY, "_init_loaders"): - OPERATOR_REGISTRY._init_loaders() - if hasattr(OPERATOR_REGISTRY, "_get_all"): - OPERATOR_REGISTRY._get_all() - - all_ops: Dict[str, List[Dict[str, Any]]] = {} - default_bucket: List[Dict[str, Any]] = [] - - idx = 1 - for op_name, cls in OPERATOR_REGISTRY: - category, info = _gather_single_operator(op_name, cls, idx) - all_ops.setdefault(category, []).append(info) - default_bucket.append(info) - idx 
+= 1 - - all_ops["Default"] = default_bucket - - RESOURCE_DIR.mkdir(parents=True, exist_ok=True) - try: - with open(OPS_JSON_PATH, "w", encoding="utf-8") as f: - json.dump(all_ops, f, ensure_ascii=False, indent=2) - log.info(f"算子信息已写入 {OPS_JSON_PATH}") - except Exception as e: - log.warning(f"写入 {OPS_JSON_PATH} 失败: {e}") - - return all_ops - -def _ensure_ops_cache() -> Dict[str, List[Dict[str, Any]]]: - """ - 若 ops.json 不存在或为空,则重新生成。 - 返回文件中的全部数据。 - """ - if OPS_JSON_PATH.exists(): - try: - with open(OPS_JSON_PATH, "r", encoding="utf-8") as f: - data = json.load(f) - if data: # 非空文件 - return data - except Exception as e: - log.warning(f"读取 {OPS_JSON_PATH} 失败,将重新生成: {e}") - return _dump_all_ops_to_file() - - -# 供 LangChain Tool 调用的主函数 -def get_operator_content(data_type: str) -> str: - """ - 根据传入的 `data_type`(即算子类别,如 "text2sql", "rag" …) - 返回该类别下所有算子的 JSON 字符串。 - - 如果该类别不存在,返回 "[]" - """ - # all_ops = _ensure_ops_cache() - all_ops = _dump_all_ops_to_file() - - import copy - - if data_type in all_ops: - content = copy.deepcopy(all_ops[data_type]) - else: - content = [] - - # 作为字符串返回,方便 LLM 直接嵌入提示词 - return json.dumps(content, ensure_ascii=False, indent=2) - - -def get_operator_content_str(data_type: str) -> str: - """ - 返回该类别下所有算子的 “name:描述” 长字符串,用分号分隔。 - """ - all_ops = _dump_all_ops_to_file() # 或 _ensure_ops_cache() - raw_items = all_ops.get(data_type, []) - - # 用英文引号,如果有需要可用中文引号 - lines = [ - f'"{item.get("name", "")}":"{item.get("description", "")}"' - for item in raw_items - ] - return "\n".join(lines) - -def get_prompt_sources_of_operator(op_name: str) -> Dict[str, str]: - """ - 获取 operator 的 prompt_templates 的源码,并随机获取2个示例 - """ - import random - cls = OPERATOR_REGISTRY.get(op_name) - if cls is None: - raise KeyError(f"Operator {op_name} not found in registry") - log.info(f"Getting prompt_sources of {op_name}") - - # 获取 prompt_templates,如果没有则抛出异常 - if getattr(cls, "ALLOWED_PROMPTS", None): - prompt_classes = cls.ALLOWED_PROMPTS - else: - raise 
ValueError(f"Operator {op_name} has no ALLOWED_PROMPTS") - - # 如果 prompt_templates 为空,则抛出异常,若只有一个,则直接使用,否则随机采样2个示例 - if len(prompt_classes) == 0: - raise ValueError(f"Operator {op_name} has no prompt_templates") - if len(prompt_classes) == 1: - sample_classes = prompt_classes - else: - sample_classes = random.sample(prompt_classes, 2) - - # 获取源码 - out = {} - for c in sample_classes: - try: - out[c.__name__] = inspect.getsource(c) - except OSError: - out[c.__name__] = "# 源码不可用(可能是C扩展/找不到源码/zip导入)" - return out - -def get_operators_info_by_names(operator_names: List[str]) -> str: - """ - 根据算子名称列表获取基本信息(node, name, description, category)。 - - Args: - operator_names: 算子名称列表,如 ['ExtractSmilesFromText', 'LLMLanguageFilter', ...] - - Returns: - 包含所有指定算子基本信息的JSON字符串。 - 如果某个算子不存在,会在结果中标注 "error" 字段。 - """ - # 初始化 OPERATOR_REGISTRY - if hasattr(OPERATOR_REGISTRY, "_init_loaders"): - OPERATOR_REGISTRY._init_loaders() - if hasattr(OPERATOR_REGISTRY, "_get_all"): - OPERATOR_REGISTRY._get_all() - - # 构建名称到类的映射 - name_to_cls = {name: cls for name, cls in OPERATOR_REGISTRY} - - # 收集结果 - results = [] - idx = 1 - - for op_name in operator_names: - cls = name_to_cls.get(op_name) - if cls is None: - # 算子不存在 - results.append({ - "node": idx, - "name": op_name, - "error": f"算子 '{op_name}' 未在 OPERATOR_REGISTRY 中注册" - }) - else: - # 获取分类 - category = "unknown" - if hasattr(cls, "__module__"): - parts = cls.__module__.split(".") - if len(parts) >= 3 and parts[0] == "dataflow" and parts[1] == "operators": - category = parts[2] - - # 获取描述 - description = _call_get_desc_static(cls, lang="zh") or "" - - # 只返回基本信息 - results.append({ - "node": idx, - "name": op_name, - "description": description, - "category": category - }) - idx += 1 - - # 返回 JSON 字符串 - return json.dumps(results, ensure_ascii=False, indent=2) - -def get_operator_source_by_name(operator_name: str) -> str: - """ - 根据算子名称获取算子的源码。 - 参数: - operator_name: 算子名称(注册在 OPERATOR_REGISTRY 中) - 返回: - 源码字符串或错误提示信息 - """ - try: - # 初始化 
OPERATOR_REGISTRY(如果需要) - if hasattr(OPERATOR_REGISTRY, "_init_loaders"): - OPERATOR_REGISTRY._init_loaders() - if hasattr(OPERATOR_REGISTRY, "_get_all"): - OPERATOR_REGISTRY._get_all() - - # 遍历注册的算子,找到匹配的名称 - for name, cls in OPERATOR_REGISTRY: - if name == operator_name: - # 获取源码 - try: - source_code = inspect.getsource(cls) - return source_code - except Exception as e: - return f"# 无法获取源码: {e}" - - # 如果未找到对应的算子名称 - return f"# 未找到算子 '{operator_name}',请检查名称是否正确。" - - except Exception as e: - return f"# 获取算子源码时发生错误: {e}" - -def get_prompt_sources_of_operator(op_name: str) -> Dict[str, str]: - """ - 获取 operator 的 prompt_templates 的源码,并随机获取2个示例 - """ - import random - cls = OPERATOR_REGISTRY.get(op_name) - if cls is None: - raise KeyError(f"Operator {op_name} not found in registry") - log.info(f"Getting prompt_sources of {op_name}") - - # 获取 prompt_templates,如果没有则抛出异常 - if getattr(cls, "ALLOWED_PROMPTS", None): - prompt_classes = cls.ALLOWED_PROMPTS - else: - raise ValueError(f"Operator {op_name} has no ALLOWED_PROMPTS") - - # 如果 prompt_templates 为空,则抛出异常,若只有一个,则直接使用,否则随机采样2个示例 - if len(prompt_classes) == 0: - raise ValueError(f"Operator {op_name} has no prompt_templates") - if len(prompt_classes) == 1: - sample_classes = prompt_classes - else: - sample_classes = random.sample(prompt_classes, 2) - - # 获取源码 - out = {} - for c in sample_classes: - try: - out[c.__name__] = inspect.getsource(c) - except OSError: - out[c.__name__] = "# 源码不可用(可能是C扩展/找不到源码/zip导入)" - return out - -def post_process_combine_pipeline_result(results: Dict) -> str: - - return "hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh" - - -# if __name__ == "__main__": -# log.info(get_operator_content("text2sql")) - - -# =================================================================== 算子RAG部分代码: -import os -import json -import pickle -import httpx -import numpy as np -import faiss -from typing import List, Dict, Union, Optional - -import dataflow_agent.utils as utils - -def 
_call_openai_embedding_api( - texts: List[str], - model_name: str = "text-embedding-ada-002", - base_url: str = "https://api.openai.com/v1/embeddings", - api_key: str | None = None, - timeout: float = 120.0, -) -> np.ndarray: - """调用OpenAI API获取文本向量""" - if api_key is None: - api_key = os.getenv("DF_API_KEY") - if not api_key: - raise RuntimeError("必须提供 OpenAI API-Key,可通过参数或环境变量 DF_API_KEY") - - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } - - vecs: List[List[float]] = [] - with httpx.Client(timeout=timeout) as client: - for t in texts: - resp = client.post( - base_url, - headers=headers, - json={"model": model_name, "input": t}, - ) - try: - resp.raise_for_status() - except httpx.HTTPStatusError as e: - raise RuntimeError(f"调用 OpenAI embedding 失败: {e}\n{resp.text}") from e - - try: - data = resp.json() - vec = data["data"][0]["embedding"] - except Exception as e: - raise RuntimeError(f"解析返回 JSON 失败: {resp.text}") from e - - vecs.append(vec) - - arr = np.asarray(vecs, dtype=np.float32) - faiss.normalize_L2(arr) - return arr - - -class RAGOperatorSearch: - """RAG 算子检索类,支持向量持久化和批量查询""" - - def __init__( - self, - ops_json_path: str, - category: Optional[str] = None, - faiss_index_path: Optional[str] = None, - model_name: str = "text-embedding-ada-002", - base_url: str = "https://api.openai.com/v1/embeddings", - api_key: Optional[str] = None, - ): - """ - 初始化 RAG 检索器 - - Args: - ops_json_path: 算子JSON文件路径 - category: 算子类别,如果为None则读取全部 - faiss_index_path: FAISS索引文件路径,如果存在则复用,否则生成并保存 - model_name: embedding模型名称 - base_url: API base URL - api_key: OpenAI API key - """ - self.ops_json_path = ops_json_path - self.category = category - self.faiss_index_path = faiss_index_path - self.model_name = model_name - self.base_url = base_url - self.api_key = api_key - - self.index = None - self.ops_list = [] - - self._load_or_build_index() - - def _load_operators(self) -> List[Dict]: - """加载算子数据""" - with open(self.ops_json_path, "r", 
encoding="utf-8") as f: - all_ops = json.load(f) - - if self.category: - # 指定类别 - ops = all_ops.get(self.category, []) - log.info(f"✓ 加载类别 '{self.category}' 的算子: {len(ops)} 个") - else: - # 读取全部类别 - 直接使用 "Default" 避免重复加载 - ops = all_ops.get("Default", []) - log.info(f"✓ 加载全部算子: {len(ops)} 个") - - return ops - - def _load_or_build_index(self): - """加载或构建FAISS索引""" - # 检查是否可以复用索引 - if self.faiss_index_path and os.path.exists(self.faiss_index_path): - meta_path = self.faiss_index_path + ".meta" - if os.path.exists(meta_path): - log.info(f"✓ 从 {self.faiss_index_path} 加载已有索引...") - self.index = faiss.read_index(self.faiss_index_path) - with open(meta_path, "rb") as f: - self.ops_list = pickle.load(f) - log.info(f"✓ 索引加载成功,包含 {len(self.ops_list)} 个算子") - return - - # 先用最新的 OPERATOR_REGISTRY 刷新 ops.json 快照 - log.info("⚙ 正在刷新 ops.json 算子快照...") - _dump_all_ops_to_file() - - # 重新构建索引 - log.info("⚙ 开始构建新的向量索引...") - self.ops_list = self._load_operators() - - if not self.ops_list: - raise ValueError("没有找到任何算子数据!") - - # 生成文本描述 - texts = [f"{op['name']} {op.get('description', '')}" for op in self.ops_list] - - # 调用API获取向量 - log.info(f"⚙ 正在获取 {len(texts)} 个算子的 embedding...") - embeddings = _call_openai_embedding_api( - texts, - model_name=self.model_name, - base_url=self.base_url, - api_key=self.api_key, - ) - - # 构建FAISS索引 - dim = embeddings.shape[1] - self.index = faiss.IndexFlatIP(dim) - self.index.add(embeddings) - log.info(f"✓ 索引构建完成,维度: {dim}") - - # 保存索引(如果指定了路径) - if self.faiss_index_path: - # 确保目录存在 - os.makedirs(os.path.dirname(self.faiss_index_path) or ".", exist_ok=True) - log.info(f"⚙ 保存索引到 {self.faiss_index_path}...") - faiss.write_index(self.index, self.faiss_index_path) - with open(self.faiss_index_path + ".meta", "wb") as f: - pickle.dump(self.ops_list, f) - log.info("✓ 索引保存成功") - - def search( - self, - queries: Union[str, List[str]], - top_k: int = 5, - return_scores: bool = False - ) -> Union[List[str], List[List[str]], List[Dict[str, Any]], 
List[List[Dict[str, Any]]]]: - """ - 检索最相关的算子 - - Args: - queries: 单个查询字符串或查询列表 - top_k: 返回top-k个结果 - return_scores: 是否返回相似度分数 - - Returns: - 如果 return_scores=False: - 如果输入是字符串,返回List[str] - 如果输入是列表,返回List[List[str]] - 如果 return_scores=True: - 如果输入是字符串,返回List[Dict],每个Dict包含 name, description, similarity_score - 如果输入是列表,返回List[List[Dict]] - """ - # 统一处理为列表 - is_single = isinstance(queries, str) - if is_single: - queries = [queries] - - # 批量获取query向量 - query_vecs = _call_openai_embedding_api( - queries, - model_name=self.model_name, - base_url=self.base_url, - api_key=self.api_key, - ) - - # 检索 - D, I = self.index.search(query_vecs, top_k) - - # 组织结果 - results = [] - for i, (indices, scores) in enumerate(zip(I, D)): - if return_scores: - # 返回包含分数的详细信息 - matched_ops = [] - for idx, score in zip(indices, scores): - op_info = self.ops_list[idx] - matched_ops.append({ - "name": op_info["name"], - "description": op_info.get("description", ""), - "similarity_score": float(score) # FAISS cosine similarity score - }) - results.append(matched_ops) - log.info(f"Query {i+1}: '{queries[i][:50]}...' -> {[(op['name'], round(op['similarity_score'], 3)) for op in matched_ops]}") - else: - # 原有逻辑,只返回名称列表 - matched_ops = [self.ops_list[idx]["name"] for idx in indices] - results.append(matched_ops) - log.info(f"Query {i+1}: '{queries[i][:50]}...' 
-> {matched_ops}") - - # 如果是单查询,返回单个列表 - return results[0] if is_single else results - - -def get_operators_by_rag( - search_queries: Union[str, List[str]], - category: Optional[str] = None, - top_k: int = 4, - ops_json_path: str = utils.get_project_root() / "dataflow_agent/toolkits/resources/ops.json", - faiss_index_path: Optional[str] = None, - model_name: str = "text-embedding-3-small", - base_url: str = "http://123.129.219.111:3000/v1/embeddings", - api_key: str = os.getenv("DF_API_KEY"), -) -> Union[List[str], List[List[str]]]: - """ - 通过RAG检索算子 - - Args: - search_queries: 单个查询字符串 或 查询列表 ['xxx1', 'xxx2'] - category: 算子类别,None表示读取全部 - top_k: 每个查询返回top-k结果 - ops_json_path: 算子JSON文件路径 - faiss_index_path: FAISS索引文件路径,如果存在则复用,否则重新生成 - model_name: embedding模型 - base_url: API地址 - api_key: API密钥 - - Returns: - 单查询返回List[str],多查询返回List[List[str]] - - Examples: - # 单个查询 - result = get_operators_by_rag("将自然语言转换为SQL") - # 返回: ['op1', 'op2', 'op3', 'op4'] - - # 批量查询 - results = get_operators_by_rag(['query1', 'query2']) - # 返回: [['op1', 'op2'], ['op3', 'op4']] - """ - searcher = RAGOperatorSearch( - ops_json_path=ops_json_path, - category=category, - faiss_index_path=faiss_index_path, - model_name=model_name, - base_url=base_url, - api_key=api_key, - ) - - return searcher.search(search_queries, top_k=top_k) - - -def local_tool_for_get_match_operator_code(pre_task_result): - import time - import sys - import inspect - from dataflow.utils.registry import OPERATOR_REGISTRY - - start_time = time.time() - if not pre_task_result or not isinstance(pre_task_result, dict): - return "# ❗ pre_task_result is empty, cannot extract operator names" - - _NAME2CLS = {name: cls for name, cls in OPERATOR_REGISTRY} - - blocks = [] - for op_name in pre_task_result.get("match_operators", [])[:2]: - cls = _NAME2CLS.get(op_name) - if cls is None: - blocks.append(f"# --- {op_name} is not registered in OPERATOR_REGISTRY ---") - continue - try: - cls_src = inspect.getsource(cls) - module_src = 
inspect.getsource(sys.modules[cls.__module__]) - import_lines = [ - l for l in module_src.splitlines() - if l.strip().startswith(("import ", "from ")) - ] - import_block = "\n".join(import_lines) - src_block = f"# === Source of {op_name} ===\n{import_block}\n\n{cls_src}" - blocks.append(src_block) - except (OSError, TypeError) as e: - blocks.append(f"# --- Failed to get the source code of {op_name}: {e} ---") - - elapsed = time.time() - start_time - log.info(f"[local_tool_for_get_match_operator_code] Time used: {elapsed:.4f} seconds") - return "\n\n".join(blocks) - - -# =================================================================== LangChain Tool 封装的 RAG 工具: - -# 匹配质量阈值定义 -MATCH_QUALITY_THRESHOLDS = { - "high": 0.5, # >= 0.5 为高度匹配 - "medium": 0.3, # >= 0.3 为中等匹配 - # < 0.3 为低匹配 -} - - -# 默认 FAISS 索引缓存路径 -DEFAULT_FAISS_INDEX_PATH = str(utils.get_project_root() / "dataflow_agent/resources/faiss_cache/all_ops.index") - - -def _get_operators_by_rag_with_scores( - search_query: str, - top_k: int = 4, - ops_json_path: str = None, - faiss_index_path: str = None, - model_name: str = "text-embedding-3-small", - base_url: str = "http://123.129.219.111:3000/v1/embeddings", - api_key: str = None, -) -> List[Dict[str, Any]]: - """ - 通过RAG检索算子,返回包含相似度分数的详细结果 - - Args: - search_query: 搜索查询 - top_k: 返回top-k结果 - ops_json_path: 算子JSON文件路径 - faiss_index_path: FAISS索引文件路径,如果存在则复用,否则生成并保存 - model_name: embedding模型 - base_url: API地址 - api_key: API密钥 - - Returns: - List[Dict],每个Dict包含 name, description, similarity_score - """ - if ops_json_path is None: - ops_json_path = utils.get_project_root() / "dataflow_agent/toolkits/resources/ops.json" - if faiss_index_path is None: - faiss_index_path = DEFAULT_FAISS_INDEX_PATH - if api_key is None: - api_key = os.getenv("DF_API_KEY") - - searcher = RAGOperatorSearch( - ops_json_path=str(ops_json_path), - category=None, - faiss_index_path=faiss_index_path, - model_name=model_name, - base_url=base_url, - api_key=api_key, - ) - - return 
searcher.search(search_query, top_k=top_k, return_scores=True) - - -def _determine_match_quality(max_score: float) -> str: - """根据最高相似度分数判断匹配质量""" - if max_score >= MATCH_QUALITY_THRESHOLDS["high"]: - return "high" - elif max_score >= MATCH_QUALITY_THRESHOLDS["medium"]: - return "medium" - else: - return "low" - - -def _generate_match_warning(query: str, max_score: float, match_quality: str) -> Optional[str]: - """根据匹配质量生成警告信息""" - if match_quality == "high": - return None - elif match_quality == "medium": - return ( - f"提示:与'{query}'相关的算子匹配度为中等(最高相似度: {max_score:.3f})。" - f"请仔细阅读算子描述,确认是否满足您的需求。" - ) - else: # low - return ( - f"警告:未找到与'{query}'高度匹配的算子。最高相似度仅为{max_score:.3f}," - f"低于推荐阈值{MATCH_QUALITY_THRESHOLDS['medium']}。" - f"当前返回的算子可能无法满足您的需求。如果没有合适的算子," - f"请在回复中说明'未能找到满足{query}需求的算子'。" - ) - - -@tool -def search_operator_by_description(query: str, top_k: int = 4) -> str: - """ - 根据功能描述搜索最匹配的数据处理算子。 - - 当需要在 pipeline 中添加新算子时,必须先调用此工具搜索真实存在的算子。 - 禁止使用此工具返回结果之外的算子名称。 - - **重要**:该工具会返回匹配质量评估(match_quality): - - "high": 高度匹配(相似度>=0.5),可以放心使用 - - "medium": 中等匹配(相似度0.3-0.5),请仔细确认是否满足需求 - - "low": 低匹配(相似度<0.3),可能无法满足需求,请考虑说明"未能找到满足需求的算子" - - Args: - query: 算子功能描述,例如 "情感分析"、"数据清洗"、"文本分类"、"去重"、"数据增强" 等 - top_k: 返回的候选算子数量,默认为4 - - Returns: - JSON 格式的搜索结果,包含匹配的算子名称、描述、相似度分数和匹配质量评估 - - Examples: - >>> search_operator_by_description("情感分析") - >>> search_operator_by_description("数据去重", top_k=3) - """ - try: - # 调用 RAG 检索(返回包含分数的详细结果) - matched_operators = _get_operators_by_rag_with_scores(query, top_k=top_k) - - # 计算最高相似度分数 - max_score = 0.0 - if matched_operators: - max_score = max(op.get("similarity_score", 0.0) for op in matched_operators) - - # 判断匹配质量 - match_quality = _determine_match_quality(max_score) - - # 生成警告信息 - warning = _generate_match_warning(query, max_score, match_quality) - - # 构建返回结果 - result = { - "query": query, - "matched_operators": matched_operators, - "max_similarity_score": round(max_score, 4), - "match_quality": match_quality, - } - - # 
添加警告信息(如果有) - if warning: - result["warning"] = warning - - # 根据匹配质量生成不同的指导说明 - if match_quality == "high": - result["instruction"] = ( - "请从 matched_operators 中选择最合适的算子名称(name字段)。" - "匹配质量高,可以放心使用。" - ) - elif match_quality == "medium": - result["instruction"] = ( - "请从 matched_operators 中选择最合适的算子名称(name字段)。" - "注意:匹配质量为中等,请仔细阅读算子描述(description)确认是否满足需求。" - ) - else: # low - result["instruction"] = ( - "注意:当前匹配质量较低!请仔细评估 matched_operators 中的算子是否能满足需求。" - f"如果没有合适的算子,请在回复中明确说明'未能找到满足「{query}」需求的算子'," - "并给出建议(如:建议用户自定义算子,或使用其他方式实现该功能)。" - ) - - log.info( - f"[search_operator_by_description] 查询: '{query}' -> " - f"匹配到 {len(matched_operators)} 个算子, " - f"最高相似度: {max_score:.3f}, 匹配质量: {match_quality}" - ) - return json.dumps(result, ensure_ascii=False, indent=2) - - except Exception as e: - log.error(f"[search_operator_by_description] 搜索失败: {e}") - return json.dumps({ - "error": str(e), - "query": query, - "matched_operators": [], - "match_quality": "error" - }, ensure_ascii=False) - - -@tool -def get_operator_code_by_name(operator_name: str) -> str: - """ - 根据算子名称获取算子的源代码。 - - 在选择了要使用的算子后,可以调用此工具获取算子的源代码, - 以便了解算子的 init 参数和 run 参数的具体用法。 - - Args: - operator_name: 算子名称,必须是 search_operator_by_description 返回的算子名称 - - Returns: - 算子的源代码字符串 - """ - try: - code = get_operator_source_by_name(operator_name) - log.info(f"[get_operator_code_by_name] 获取算子 '{operator_name}' 的源代码成功") - return code - except Exception as e: - log.error(f"[get_operator_code_by_name] 获取失败: {e}") - return f"# 获取算子 '{operator_name}' 源代码失败: {e}" - - -if __name__ == "__main__": - # ============ 示例1: 单个查询 + 指定category + 持久化索引 ============ - # log.info("\n" + "="*70) - # log.info("示例1: 单个查询 + 指定category + 持久化索引") - # log.info("="*70) - # result1 = get_operators_by_rag( - # search_queries="将自然语言转换为SQL查询语句", - # category="text2sql", - # top_k=3, - # faiss_index_path="./faiss_cache/text2sql.index" # 第一次生成,后续复用 - # ) - # log.info(f"\n返回结果: {result1}\n") - - # ============ 示例2: 批量查询 + 读取全部category ============ - 
log.info("\n" + "="*70) - log.info("示例2: 批量查询 + 读取全部category") - log.info("="*70) - queries = [ - "数据清洗和预处理", - "文本分类任务", - "生成SQL语句" - ] - result2 = get_operators_by_rag( - search_queries=queries, - category=None, # 不指定category,读取全部 - top_k=4, - faiss_index_path="" - ) - log.info(f"\n返回结果: {result2}\n") - - # ============ 示例3: 不持久化,每次重新生成 ============ - # log.info("\n" + "="*70) - # log.info("示例3: 不持久化索引,每次重新生成") - # log.info("="*70) - # result3 = get_operators_by_rag( - # search_queries=["数据可视化", "模型训练"], - # category="text2sql", - # top_k=3, - # faiss_index_path=None # 不指定路径,不持久化 - # ) - # log.info(f"\n返回结果: {result3}\n") - - # ============ 示例4: 自定义top_k ============ - # log.info("\n" + "="*70) - # log.info("示例4: 自定义top_k=5") - # log.info("="*70) - # result4 = get_operators_by_rag( - # search_queries="数据库查询", - # top_k=5, - # faiss_index_path="./faiss_cache/all_ops.index" - # ) - # log.info(f"\n返回结果: {result4}\n") diff --git a/dataflow_agent/toolkits/p2vtool/__init__.py b/dataflow_agent/toolkits/p2vtool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/p2vtool/p2v_tool.py b/dataflow_agent/toolkits/p2vtool/p2v_tool.py deleted file mode 100644 index 0d65ffe..0000000 --- a/dataflow_agent/toolkits/p2vtool/p2v_tool.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import annotations - -from dataflow_agent.logger import get_logger -import subprocess -from pathlib import Path -from typing import Dict, Any, Tuple, Optional, List -import torch - -log = get_logger(__name__) -import re - -def get_image_paths(directory_path: str) -> List[str]: - """ - 遍历指定目录及其子目录,查找所有常见的图片文件,并返回它们的路径字符串列表。 - """ - # 1. 常用图片文件扩展名列表 - image_extensions = [ - '*.png', '*.jpg', '*.jpeg', '*.gif', '*.bmp', '*.svg', '*.webp' - ] - - base_path = Path(directory_path) - if not base_path.is_dir(): - # 如果目录不存在,返回空列表并打印错误 - print(f"Error: Directory not found at {directory_path}") - return [] - - found_image_paths: List[Path] = [] - - # 2. 
递归遍历目录并收集路径 - for ext in image_extensions: - # rglob(ext) 查找所有匹配该扩展名的文件,无论嵌套多深 - # extend() 将迭代器的所有元素添加到列表中 - found_image_paths.extend(base_path.rglob(ext)) - - #3. 对找到的图片路径按照文件名日期进行排序,确保顺序 - def natural_sort_key(path: Path): - file_name = path.name - numbers = re.findall(r'(\d+)', file_name) - return tuple(int(n) for n in numbers) - - found_image_paths.sort(key=natural_sort_key) - return [str(p.resolve()) for p in found_image_paths] - - -def parse_script(script_text): - ''' - 解析脚本的内容,将其分割成(prompt, cursor_prompt)两部分 - ''' - pages = script_text.strip().split("###\n") - result = [] - for page in pages: - if not page.strip(): continue - lines = page.strip().split("\n") - page_data = [] - for line in lines: - if "|" not in line: - continue - text, cursor = line.split("|", 1) - page_data.append([text.strip(), cursor.strip()]) - result.append(page_data) - return result - -def transcribe_with_whisperx(audio_path, lang="en", device="cuda" if torch.cuda.is_available() else "cpu"): - '''根据ref_audio生成对应的ref_text,从而在后续使用f5模型时,提供对齐文本,更好的提高最后audio的效果''' - import whisperx - log.info(f"transcribe_with_whisperx 使用了 device: {device}") - model = whisperx.load_model("large-v2", device=device, compute_type="float16" if device == "cuda" else "int8") - result = model.transcribe(audio_path, language=lang) - model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) - result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_path, device) - segments = result_aligned["segments"] - text = " ".join(seg["text"].strip() for seg in segments) - return text - -def inference_f5(text_prompt, save_path, ref_audio, ref_text): - from f5_tts.api import F5TTS - f5tts = F5TTS() - f5tts.infer(ref_file=ref_audio, ref_text=ref_text, gen_text=text_prompt, file_wave=save_path, seed=None,) - - -def extract_beamer_code(text_str): - match = re.search(r"(\\documentclass(?:\[[^\]]*\])?\{beamer\}.*?\\end\{document\})", text_str, re.DOTALL) - return 
match.group(1) if match else None - -def compile_tex(beamer_code_path: str): - tex_path = Path(beamer_code_path).resolve() - if not tex_path.exists(): - raise FileNotFoundError(f"Tex file {tex_path} does not exist.") - work_dir = tex_path.parent - try: - # 会编译.tex文件,然后创建好一个.pdf文件 - result = subprocess.run( - ["tectonic", str(tex_path)], - check=True, - capture_output=True, - text=True, - ) - code_debug_result = "\n".join([result.stdout, result.stderr]) - log.info(f"Beamer 编译成功,输出结果:{code_debug_result}") - is_beamer_warning = False - if 'warning' in code_debug_result: - is_beamer_warning = True - log.info(f"Beamer 代码存在warning,需要更加完善一下") - is_beamer_wrong = False - return is_beamer_wrong, is_beamer_warning, code_debug_result - except subprocess.CalledProcessError as e: - log.info(f"Beamer 编译失败: {e.stderr}") - is_beamer_wrong = True - is_beamer_warning = True - code_debug_result = e.stderr - return is_beamer_wrong, is_beamer_warning, code_debug_result - -def beamer_code_validator(content: str, parsed_result: Dict[str, Any]) -> Tuple[bool, Optional[str]]: - """检查tex是否是正确的""" - from tempfile import TemporaryDirectory - - # 这里的 dir 具体是什么无所谓,因为我latex code中的图像路径是绝对路径 - with TemporaryDirectory() as temp_dir_name: - temp_dir = Path(temp_dir_name) - # 在临时目录中创建 .tex 文件 - # todo: 这里可能需要修改一下,因为在临时目录下创建文件还是不太行。 - tex_path = temp_dir / "input.tex" - - raw_beamer_code = parsed_result.get("latex_code", "") - if not raw_beamer_code: - log.error(f"The content of beamer code is empty!") - return False, "The content of beamer code is empty!" - beamer_code = extract_beamer_code(raw_beamer_code) - try: - # 1. 
写入内容 - tex_path.write_text(beamer_code, encoding='utf-8') - - result = subprocess.run( - ["tectonic", str(tex_path)], - check=True, - capture_output=True, - text=True, - cwd=temp_dir - ) - log.info(f"Beamer代码修改完成,没有出现error") - code_debug_result = "\n".join([result.stdout, result.stderr]) - return True, None - - except subprocess.CalledProcessError as e: - code_debug_result = f"STDOUT:\n{e.stdout}\n\nSTDERR:\n{e.stderr}" - return False, code_debug_result diff --git a/dataflow_agent/toolkits/pipetool/__init__.py b/dataflow_agent/toolkits/pipetool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/pipetool/pipe_tools.py b/dataflow_agent/toolkits/pipetool/pipe_tools.py deleted file mode 100644 index e2e3f1c..0000000 --- a/dataflow_agent/toolkits/pipetool/pipe_tools.py +++ /dev/null @@ -1,1773 +0,0 @@ -# dataflow/dataflowagent/toolkits/pipeline_assembler.py -from __future__ import annotations - -import ast -import json -import itertools -from pathlib import Path -from typing import Any, Dict, List, Tuple, DefaultDict -import requests -from dataflow_agent.state import DFState,DFRequest -import importlib -import inspect -import re -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -from dataflow_agent.logger import get_logger -from dataflow.utils.registry import OPERATOR_REGISTRY - -log = get_logger(__name__) - -EXTRA_IMPORTS: set[str] = set() - -# "pipeline_assembler", # 核心入口:返回 {"pipe_code": ...} -# "build_pipeline_code", # 主体:组装 pipeline 代码 -# "choose_prompt_template_by_llm", # LLM智能选择 prompt 模板 -# "render_operator_blocks", # 生成 operator 初始化与调用代码 -# "group_imports", # 汇总依赖导入 -# "extract_op_params", # 提取 operator 参数 -# "choose_prompt_template", # prompt_template 兜底选择 - -def call_llm_for_selection( - system_prompt: str, - user_message: str, - api_url: str, - api_key: str, - model: str, - temperature: float = 0.3, - max_tokens: int = 100 -) -> str: - """ - 调用 LLM API 
进行选择决策 - - Args: - system_prompt: 系统提示词 - user_message: 用户消息 - api_url: API 地址(OpenAI 兼容格式) - api_key: API 密钥 - model: 模型名称 - temperature: 温度参数 - max_tokens: 最大 token 数 - - Returns: - LLM 返回的文本内容 - """ - if not api_url.endswith('/chat/completions'): - if api_url.endswith('/'): - api_url = api_url + 'chat/completions' - else: - api_url = api_url + '/chat/completions' - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - payload = { - "model": model, - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message} - ], - "temperature": temperature, - "max_tokens": max_tokens - } - - try: - response = requests.post(api_url, headers=headers, json=payload, timeout=30) - response.raise_for_status() - result = response.json() - - # 提取返回的内容 - content = result.get('choices', [{}])[0].get('message', {}).get('content', '').strip() - log.info(f"[pipeline_assembler] LLM selection result: {content}") - return content - - except Exception as e: - log.error(f"[pipeline_assembler] LLM API call failed: {e}") - raise - - -def extract_prompt_info(prompt_cls: type) -> Dict[str, Any]: - """ - 提取 prompt 类的详细信息,包括示例提示词 - - Args: - prompt_cls: Prompt 类对象 - - Returns: - 包含类名、模块、文档字符串和示例提示词的字典 - """ - prompt_info = { - 'class_name': prompt_cls.__qualname__, - 'module': prompt_cls.__module__, - 'docstring': (prompt_cls.__doc__ or '').strip(), - } - - # 尝试实例化并获取示例提示词 - try: - instance = prompt_cls() - - # 如果有 build_prompt 方法 - if hasattr(instance, 'build_prompt'): - sig = inspect.signature(instance.build_prompt) - params = list(sig.parameters.keys()) - - # 构造示例参数 - example_args = {} - for param in params: - if param == 'self': - continue - # 使用占位符 - example_args[param] = f"" - - try: - # 调用 build_prompt 获取完整的提示词模板 - example_prompt = instance.build_prompt(**example_args) - # 截取前 800 字符避免过长 - prompt_info['full_prompt_template'] = example_prompt[:800] - if len(example_prompt) > 800: - 
prompt_info['full_prompt_template'] += "\n...[truncated]" - except Exception as e: - log.warning(f"[pipeline_assembler] Failed to get example prompt for {prompt_cls.__name__}: {e}") - prompt_info['full_prompt_template'] = "Unable to generate example" - - # 如果有其他可用的属性,也可以提取 - if hasattr(instance, 'template'): - prompt_info['template_attr'] = str(instance.template)[:200] - - except Exception as e: - log.warning(f"[pipeline_assembler] Failed to instantiate {prompt_cls.__name__}: {e}") - prompt_info['full_prompt_template'] = "Unable to instantiate" - - return prompt_info - - -def choose_prompt_template_by_llm(op_name: str, state: DFState) -> str: - """ - 通过 LLM 选择最合适的 prompt_template - - 规则: - 1. 提取 operator 的所有 ALLOWED_PROMPTS 候选 - 2. 获取每个 prompt 的详细信息(包括提示词模板) - 3. 调用 LLM 让它根据 target 任务描述选择最合适的 prompt - 4. 返回选中 prompt 的实例化代码字符串 - - Args: - op_name: Operator 名称 - state: DFState 对象,包含 request.target 等信息 - - Returns: - 选中的 prompt_template 实例化代码字符串 - """ - cls = OPERATOR_REGISTRY.get(op_name) - if cls is None: - raise KeyError(f"Operator {op_name} not found in registry") - - # 如果没有 ALLOWED_PROMPTS 或为空,回退到原逻辑 - allowed_prompts = getattr(cls, "ALLOWED_PROMPTS", None) - if not allowed_prompts: - log.info(f"[pipeline_assembler] No ALLOWED_PROMPTS for {op_name}, using default logic") - return choose_prompt_template(op_name, state) - - # 如果只有一个候选,直接使用 - if len(allowed_prompts) == 1: - prompt_cls = allowed_prompts[0] - EXTRA_IMPORTS.add(f"from {prompt_cls.__module__} import {prompt_cls.__qualname__}") - return f"{prompt_cls.__qualname__}()" - - # 收集所有候选 prompt 的详细信息 - log.info(f"[pipeline_assembler] Extracting info from {len(allowed_prompts)} prompt candidates") - prompt_candidates = [] - for prompt_cls in allowed_prompts: - prompt_info = extract_prompt_info(prompt_cls) - prompt_candidates.append(prompt_info) - - # 构造 LLM 请求 - target = state.request.target - system_prompt = """You are an expert at selecting the most appropriate prompt template for a given task. 
- -Your job is to: -1. Analyze the target task description -2. Review all available prompt templates (including their documentation and example prompts) -3. Select the MOST suitable prompt template - -IMPORTANT: -1.Respond with ONLY the exact class name of the selected prompt template, nothing else. -2. 禁止返回 `Diy开头的` 这个类名,无论如何都不要选择它。 -""" - - user_message = f"""Target Task Description: -{target} - -Available Prompt Templates: -""" - - for i, p in enumerate(prompt_candidates, 1): - user_message += f"\n{'='*60}\n" - user_message += f"Option {i}: {p['class_name']}\n" - user_message += f"{'='*60}\n" - - if p['docstring']: - user_message += f"Documentation:\n{p['docstring']}\n\n" - - if 'full_prompt_template' in p: - user_message += f"Prompt Template Example:\n{p['full_prompt_template']}\n" - - if 'template_attr' in p: - user_message += f"Template: {p['template_attr']}\n" - - user_message += f"\n{'='*60}\n" - user_message += "\nBased on the target task, which prompt template is most suitable?\n" - user_message += "Respond with ONLY the class name (e.g., 'MathAnswerGeneratorPrompt')." 
- - # 调用 LLM - try: - selected_class_name = call_llm_for_selection( - system_prompt=system_prompt, - user_message=user_message, - api_url=state.request.chat_api_url, - api_key=state.request.api_key, - model=state.request.model - ) - - # 清理返回结果(移除可能的引号、空格等) - selected_class_name = selected_class_name.strip().strip('"\'`') - - # 找到对应的 prompt class - for prompt_cls in allowed_prompts: - if prompt_cls.__qualname__ == selected_class_name or prompt_cls.__name__ == selected_class_name: - log.critical(f"[pipeline_assembler] 大模型选择了这个提示词模板: {prompt_cls.__qualname__}") - EXTRA_IMPORTS.add(f"from {prompt_cls.__module__} import {prompt_cls.__qualname__}") - return f"{prompt_cls.__qualname__}()" - - # 如果没找到精确匹配,尝试模糊匹配 - for prompt_cls in allowed_prompts: - if selected_class_name in prompt_cls.__qualname__ or prompt_cls.__name__ in selected_class_name: - log.warning(f"[pipeline_assembler] Using fuzzy match for '{selected_class_name}' -> {prompt_cls.__qualname__}") - EXTRA_IMPORTS.add(f"from {prompt_cls.__module__} import {prompt_cls.__qualname__}") - return f"{prompt_cls.__qualname__}()" - - # 如果还是没找到,使用第一个作为默认 - log.warning(f"[pipeline_assembler] LLM selected unknown prompt '{selected_class_name}', using first available") - - except Exception as e: - log.error(f"[pipeline_assembler] LLM selection failed: {e}, using first available prompt") - - # 默认使用第一个 - prompt_cls = allowed_prompts[0] - EXTRA_IMPORTS.add(f"from {prompt_cls.__module__} import {prompt_cls.__qualname__}") - return f"{prompt_cls.__qualname__}()" - - -# ================================================================================================================================== -def snake_case(name: str) -> str: - """ - Convert CamelCase (with acronyms) to snake_case. 
- Examples: - SQLGenerator -> sql_generator - HTTPRequest -> http_request - """ - s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name) - s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1) - return s2.replace("__", "_").lower() - - -def try_import(module_path: str) -> bool: - try: - importlib.import_module(module_path) - return True - except Exception as e: - log.warning(f"[pipeline_assembler] import {module_path} failed: {e}") - return False - - -def build_stub(cls_name: str, module_path: str) -> str: - return ( - f"# Fallback stub for {cls_name}, original module '{module_path}' not found\n" - f"class {cls_name}: # type: ignore\n" - f" def __init__(self, *args, **kwargs):\n" - f" import warnings; warnings.warn(\n" - f" \"Stub operator {cls_name} used, module '{module_path}' missing.\"\n" - f" )\n" - f" def run(self, *args, **kwargs):\n" - f" return kwargs.get(\"storage\") # 透传\n" - ) - - -def _normalize_module(mod: str) -> str: - """ - 将类似 - dataflow.operators.general_text.eval.langkit_sample_evaluator - 统一裁剪成 - dataflow.operators.general_text - - 规则: - 1. 仅处理以 "dataflow.operators." 开头的模块。 - 2. 只保留 "dataflow.operators.<一级子包>"。 - 3. 其余模块原样返回。 - """ - prefix = "dataflow.operators." - if mod.startswith(prefix): - # 拿掉前缀后按点分割,取第 0 个就是一级子包 - subpkg = mod[len(prefix):].split(".", 1)[0] - return f"{prefix}{subpkg}" - return mod - -def group_imports(op_names: List[str]) -> Tuple[List[str], List[str], Dict[str, type]]: - imports, stubs = [], [] - op_classes: Dict[str, type] = {} - module2names: Dict[str, List[str]] = defaultdict(list) - - for name in op_names: - cls = OPERATOR_REGISTRY.get(name) - if cls is None: - raise KeyError(f"Operator <{name}> not in OPERATOR_REGISTRY") - op_classes[name] = cls - - mod_raw = cls.__module__ # e.g. 
dataflow.operators.general_text.eval.langkit_sample_evaluator - mod = _normalize_module(mod_raw) # → dataflow.operators.general_text - - if try_import(mod): - module2names[mod].append(cls.__name__) - else: # 正常情况下进不到这里 - stubs.append(build_stub(cls.__name__, mod)) - - # 只保留一次循环 - for m in sorted(module2names.keys()): - uniq_names = sorted(set(module2names[m])) - imports.append(f"from {m} import {', '.join(uniq_names)}") - - # 追加由 choose_prompt_template 收集的 import - imports.extend(sorted(EXTRA_IMPORTS)) - return imports, stubs, op_classes - - -def _format_default(val: Any) -> str: - """ - Produce a code string for a default value. - If default is missing (inspect._empty), we return 'None' to keep code runnable. - """ - if val is inspect._empty: - return "None" - if isinstance(val, str): - return repr(val) - return repr(val) - - -def extract_op_params(cls: type) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], bool]: - """ - Inspect 'cls' for __init__ and run signatures. - - Returns: - init_kwargs: list of (param_name, code_str_default) for __init__ (excluding self) - run_kwargs: list of (param_name, code_str_default) for run (excluding self and storage) - run_has_storage: whether run(...) 
has 'storage' parameter - """ - # ---- __init__ - init_kwargs: List[Tuple[str, str]] = [] - try: - init_sig = inspect.signature(cls.__init__) - for p in list(init_sig.parameters.values())[1:]: # skip self - if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - continue - init_kwargs.append((p.name, _format_default(p.default))) - except Exception as e: - log.warning(f"[pipeline_assembler] inspect __init__ of {cls.__name__} failed: {e}") - - # ---- run - run_kwargs: List[Tuple[str, str]] = [] - run_has_storage = False - if hasattr(cls, "run"): - try: - run_sig = inspect.signature(cls.run) - params = list(run_sig.parameters.values())[1:] # skip self - for p in params: - if p.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - continue - if p.name == "storage": - run_has_storage = True - continue - run_kwargs.append((p.name, _format_default(p.default))) - except Exception as e: - log.warning(f"[pipeline_assembler] inspect run of {cls.__name__} failed: {e}") - - return init_kwargs, run_kwargs, run_has_storage - -def choose_prompt_template(op_name: str, state: DFState) -> str: - """ - 返回 prompt_template 的代码字符串。 - 规则: - 1. 若类有 ALLOWED_PROMPTS 且非空 → 取第一个并实例化; - 2. 
否则回退到 __init__ 默认值;若仍不可用则返回 None。 - """ - from dataflow.utils.registry import OPERATOR_REGISTRY - import inspect, json - - cls = OPERATOR_REGISTRY.get(op_name) - if cls is None: - raise KeyError(f"Operator {op_name} not found in registry") - - # 优先使用 ALLOWED_PROMPTS - if getattr(cls, "ALLOWED_PROMPTS", None): - prompt_cls = cls.ALLOWED_PROMPTS[0] - EXTRA_IMPORTS.add(f"from {prompt_cls.__module__} import {prompt_cls.__qualname__}") - return f"{prompt_cls.__qualname__}()" - - # -------- 无 ALLOWED_PROMPTS,兜底处理 -------- - sig = inspect.signature(cls.__init__) - p = sig.parameters.get("prompt_template") - if p is None: - # 理论上不会走到这里,因为调用方只在存在该参数时才进来 - return "None" - - default_val = p.default - if default_val in (inspect._empty, None): - return "None" - - # 基础类型可直接 repr - if isinstance(default_val, (str, int, float, bool)): - return repr(default_val) - - # 类型对象 → 加 import 然后实例化 - if isinstance(default_val, type): - EXTRA_IMPORTS.add(f"from {default_val.__module__} import {default_val.__qualname__}") - return f"{default_val.__qualname__}()" - - # UnionType / 其它复杂对象 → 字符串化再 repr,保证可写入代码 - return repr(str(default_val)) - - -def render_operator_blocks( - op_names: List[str], - op_classes: Dict[str, type], - state: DFState, - prompted_generator_prompts: Dict[int, str] = None -) -> Tuple[str, str]: - """ - Render operator initialization lines and forward-run lines without leading indentation. - Indentation will be applied by build_pipeline_code when inserting into the template. 
- - Args: - op_names: 算子名称列表 - op_classes: 算子类字典 - state: DFState 对象 - prompted_generator_prompts: 预生成的 PromptedGenerator system_prompt 映射 - 格式: {算子索引: system_prompt} - """ - init_lines: List[str] = [] - forward_lines: List[str] = [] - prompted_generator_prompts = prompted_generator_prompts or {} - - # 用于跟踪每个算子名称的出现次数 - name_count: Dict[str, int] = {} - - for i, name in enumerate(op_names): - cls = op_classes[name] - base_var_name = snake_case(cls.__name__) - - # 统计相同算子名称的出现次数 - count = name_count.get(base_var_name, 0) + 1 - name_count[base_var_name] = count - - # 如果出现多次,添加序号后缀 - if count > 1: - var_name = f"{base_var_name}_{count}" - else: - var_name = base_var_name - - init_kwargs, run_kwargs, run_has_storage = extract_op_params(cls) - - # Inject pipeline context where appropriate - rendered_init_args: List[str] = [] - for k, v in init_kwargs: - if k == "llm_serving": - rendered_init_args.append(f"{k}=self.llm_serving") - elif k == "prompt_template": - # p_t = choose_prompt_template(name, state) - # 用LLM来选择 - p_t = choose_prompt_template_by_llm(name, state) - rendered_init_args.append(f'{k}={p_t}') - elif k == "system_prompt": - # 检查是否是 PromptedGenerator 且有预生成的 prompt - if name == "PromptedGenerator" and i in prompted_generator_prompts: - # 使用预生成的 system_prompt - pre_prompt = prompted_generator_prompts[i] - rendered_init_args.append(f'{k}={repr(pre_prompt)}') - log.info(f"[render_operator_blocks] 使用预生成的 prompt 给索引 {i} 的 PromptedGenerator") - else: - rendered_init_args.append(f"{k}={v}") - else: - rendered_init_args.append(f"{k}={v}") - - init_line = f"self.{var_name} = {cls.__name__}(" + ", ".join(rendered_init_args) + ")" - init_lines.append(init_line) - - # Build run call - run_args: List[str] = [] - if run_has_storage: - run_args.append("storage=self.storage.step()") - run_args.extend([f"{k}={v}" for k, v in run_kwargs]) - - if run_args: - call = ( - f"self.{var_name}.run(\n" - f" " + ", ".join(run_args) + "\n" - f")" - ) - else: - call = 
f"self.{var_name}.run()" - forward_lines.append(call) - - return "\n".join(init_lines), "\n".join(forward_lines) - - -def indent_block(code: str, spaces: int) -> str: - """ - Indent every line of 'code' by 'spaces' spaces. Keeps internal structure. - """ - import textwrap as _tw - code = _tw.dedent(code or "").strip("\n") - if not code: - return "" - prefix = " " * spaces - return "\n".join(prefix + line if line else "" for line in code.splitlines()) - - -def write_pipeline_file( - code: str, - file_name: str = "recommend_pipeline.py", - overwrite: bool = True, -) -> Path: - """ - 把生成的 pipeline 代码写入当前文件同级目录下的 `file_name`。 - """ - target_path = Path(__file__).resolve().parent / file_name - - if target_path.exists() and not overwrite: - raise FileExistsError(f"{target_path} already exists. Set overwrite=True to replace it.") - - target_path.write_text(code, encoding="utf-8") - log.info(f"[pipeline_assembler] code written to {target_path}") - - return target_path - -# =========================================================渲染 op 的 全部函数================================================== -# def snake_case(name: str) -> str: -# """CamelCase -> snake_case""" -# s1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name) -# s2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1) -# return s2.replace("__", "_").lower() - - -# def try_import(module_path: str) -> bool: -# """尝试导入模块""" -# try: -# importlib.import_module(module_path) -# return True -# except Exception as e: -# log.warning(f"import {module_path} failed: {e}") -# return False - - -# def _normalize_module(mod: str) -> str: -# """dataflow.operators.xxx.yyy.zzz -> dataflow.operators.xxx""" -# prefix = "dataflow.operators." 
-# if mod.startswith(prefix): -# subpkg = mod[len(prefix):].split(".", 1)[0] -# return f"{prefix}{subpkg}" -# return mod - - -def group_import_for_full_params(op_names: List[str]) -> tuple: - """收集所有算子的导入语句""" - imports = [] - op_classes: Dict[str, type] = {} - module2names: Dict[str, List[str]] = defaultdict(list) - - for name in op_names: - cls = OPERATOR_REGISTRY.get(name) - if cls is None: - raise KeyError(f"Operator <{name}> not in OPERATOR_REGISTRY") - op_classes[name] = cls - - mod_raw = cls.__module__ - mod = _normalize_module(mod_raw) - - if try_import(mod): - module2names[mod].append(cls.__name__) - - for m in sorted(module2names.keys()): - uniq_names = sorted(set(module2names[m])) - imports.append(f"from {m} import {', '.join(uniq_names)}") - - # 追加额外的 import - # imports.extend(sorted(EXTRA_IMPORTS)) - return imports, op_classes - -# def render_operator_blocks_with_full_params( -# opname_and_params: List[Dict[str, Any]], -# op_classes: Dict[str, type] -# ) -> tuple: -# """ -# 渲染算子初始化和调用代码(完整支持 init + run 参数) - -# Args: -# opname_and_params: [ -# { -# "op_name": "OperatorA", -# "init_params": {"llm_serving": "...", "prompt_template": "..."}, -# "run_params": {"param1": "value1"} -# }, -# ... 
-# ] - -# Returns: -# (init_code_block, forward_code_block) -# """ -# import inspect - -# init_lines = [] -# forward_lines = [] -# # 记录同一算子类出现的次数,避免 self.xxx 被后面的覆盖 -# name_count: Dict[str, int] = {} - -# for item in opname_and_params: -# name = item["op_name"] -# custom_init_params = item.get("init_params", {}) -# custom_run_params = item.get("run_params", {}) - -# cls = op_classes[name] -# base_var_name = snake_case(cls.__name__) -# count = name_count.get(base_var_name, 0) + 1 -# name_count[base_var_name] = count -# if count > 1: -# var_name = f"{base_var_name}_{count}" -# else: -# var_name = base_var_name - -# # 检查 run 方法是否有 storage 参数 -# run_has_storage = False -# if hasattr(cls, "run"): -# try: -# run_sig = inspect.signature(cls.run) -# run_has_storage = "storage" in run_sig.parameters -# except: -# pass - -# # -------- 渲染 __init__ 参数 -------- -# init_args = [] -# for k, v in custom_init_params.items(): -# if k == "llm_serving": -# init_args.append(f"{k}=self.llm_serving") -# elif k == "prompt_template": -# # 用户已经选择了具体的 prompt 类 -# if v and v != "None": -# # v 格式:module.ClassName -# parts = v.rsplit(".", 1) -# if len(parts) == 2: -# module, classname = parts -# EXTRA_IMPORTS.add(f"from {module} import {classname}") -# init_args.append(f"{k}={classname}()") -# else: -# init_args.append(f"{k}={v}") -# else: -# init_args.append(f"{k}=None") -# else: -# # 其他参数直接使用 -# if isinstance(v, str): -# init_args.append(f"{k}={repr(v)}") -# else: -# init_args.append(f"{k}={v}") -# init_args.insert(0, 'llm_serving=self.llm_serving') #前端没传这个,直接塞进来; -# init_line = f"self.{var_name} = {cls.__name__}({', '.join(init_args)})" -# init_lines.append(init_line) - -# # -------- 渲染 run() 调用参数 -------- -# run_args = [] -# if run_has_storage: -# run_args.append("storage=self.storage.step()") - -# for k, v in custom_run_params.items(): -# if isinstance(v, str): -# run_args.append(f"{k}={repr(v)}") -# else: -# run_args.append(f"{k}={v}") - -# if run_args: -# separator = ',\n ' -# call = ( 
-# f"self.{var_name}.run(\n" -# f" {separator.join(run_args)}\n" -# f" )" -# ) -# else: -# call = f"self.{var_name}.run()" -# forward_lines.append(call) - -# return "\n ".join(init_lines), "\n ".join(forward_lines) - -def render_operator_blocks_with_full_params( - opname_and_params: List[Dict[str, Any]], - op_classes: Dict[str, type], - prompted_generator_prompts: Optional[Dict[int, str]] = None # ← 添加参数 -) -> tuple: - """ - 渲染算子初始化和调用代码(完整支持 init + run 参数) - """ - import inspect - - init_lines = [] - forward_lines = [] - name_count: Dict[str, int] = {} - prompted_gen_counter = 0 - - for idx, item in enumerate(opname_and_params): # ← 使用 enumerate 获取索引 - name = item["op_name"] - # 支持两种格式:1) init_params/run_params 分离 2) 统一的 params - custom_init_params = item.get("init_params", {}) - custom_run_params = item.get("run_params", {}) - - # 如果没有 init_params/run_params,尝试从 params 中获取 - if not custom_init_params and not custom_run_params: - all_params = item.get("params", {}) - # 根据算子的 __init__ 和 run 签名自动分配参数 - cls = op_classes[name] - init_param_names = set() - run_param_names = set() - - try: - init_sig = inspect.signature(cls.__init__) - init_param_names = {p.name for p in list(init_sig.parameters.values())[1:] - if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)} - except: - pass - - try: - run_sig = inspect.signature(cls.run) - run_param_names = {p.name for p in list(run_sig.parameters.values())[1:] - if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - and p.name != "storage"} - except: - pass - - # 分配参数 - for k, v in all_params.items(): - if k in init_param_names: - custom_init_params[k] = v - elif k in run_param_names: - custom_run_params[k] = v - else: - # 默认放到 run_params - custom_run_params[k] = v - - cls = op_classes[name] - base_var_name = snake_case(cls.__name__) - count = name_count.get(base_var_name, 0) + 1 - name_count[base_var_name] = count - if count > 1: - var_name = f"{base_var_name}_{count}" 
- else: - var_name = base_var_name - - # 检查 run 方法是否有 storage 参数 - run_has_storage = False - if hasattr(cls, "run"): - try: - run_sig = inspect.signature(cls.run) - run_has_storage = "storage" in run_sig.parameters - except: - pass - - # -------- 渲染 __init__ 参数 -------- - init_args = [] - - # 检查是否是 PromptedGenerator 且有预生成的 prompt - if name == "PromptedGenerator" and prompted_generator_prompts: - if prompted_gen_counter in prompted_generator_prompts: - pre_prompt = prompted_generator_prompts[prompted_gen_counter] - custom_init_params["system_prompt"] = pre_prompt - log.info(f"[render_operator_blocks_with_full_params] 使用预生成的 prompt (索引 {prompted_gen_counter})") - - for k, v in custom_init_params.items(): - if k == "llm_serving": - init_args.append(f"{k}=self.llm_serving") - elif k == "system_prompt": - init_args.append(f"{k}={repr(v)}") - elif k == "prompt_template": - if v and v != "None": - parts = v.rsplit(".", 1) - if len(parts) == 2: - module, classname = parts - EXTRA_IMPORTS.add(f"from {module} import {classname}") - init_args.append(f"{k}={classname}()") - else: - init_args.append(f"{k}={v}") - else: - init_args.append(f"{k}=None") - else: - if isinstance(v, str): - init_args.append(f"{k}={repr(v)}") - else: - init_args.append(f"{k}={v}") - - # ← 如果是 PromptedGenerator,递增计数器 - if name == "PromptedGenerator": - prompted_gen_counter += 1 - - init_args.insert(0, 'llm_serving=self.llm_serving') - init_line = f"self.{var_name} = {cls.__name__}({', '.join(init_args)})" - init_lines.append(init_line) - - # -------- 渲染 run() 调用参数 -------- - run_args = [] - if run_has_storage: - run_args.append("storage=self.storage.step()") - - for k, v in custom_run_params.items(): - # 跳过 storage 参数,因为已经在上面自动添加了 - if k == "storage": - continue - if isinstance(v, str): - run_args.append(f"{k}={repr(v)}") - else: - run_args.append(f"{k}={v}") - - if run_args: - separator = ',\n ' - call = ( - f"self.{var_name}.run(\n" - f" {separator.join(run_args)}\n" - f" )" - ) - else: - call = 
f"self.{var_name}.run()" - forward_lines.append(call) - - return "\n ".join(init_lines), "\n ".join(forward_lines) - - -def build_pipeline_code_with_full_params( - opname_and_params: List[Dict[str, Any]], - *, - cache_dir: str = "./cache_local", - llm_local: bool = False, - local_model_path: str = "", - chat_api_url: str = "", - model_name: str = "deepseek-v3.2", - file_path: str = "", - prompted_generator_prompts: Optional[Dict[int, str]] = None, -) -> str: - """构建完整的 pipeline 代码(支持 init + run 参数)""" - - # 清空之前的额外导入 - EXTRA_IMPORTS.clear() - - # 1) 提取所有算子名称 - op_names = [item["op_name"] for item in opname_and_params] - - # 2) 判断 cache_type - file_suffix = Path(file_path).suffix.lower() if file_path else "" - cache_type = { - ".jsonl": "jsonl", - ".json": "json", - ".csv": "csv" - }.get(file_suffix, "jsonl") - - # 3) 收集导入 - import_lines, op_classes = group_import_for_full_params(op_names) - - # 4) 渲染算子代码(传入 prompted_generator_prompts) - ops_init_block, forward_block = render_operator_blocks_with_full_params( - opname_and_params, op_classes, prompted_generator_prompts=prompted_generator_prompts - ) - - # 汇总所有导入语句,去重排序 - all_imports = import_lines + sorted(EXTRA_IMPORTS) - import_section = "\n".join(dict.fromkeys(all_imports)) - - # 5) LLM Serving(生成无缩进的代码块) - if llm_local: - llm_block = f'''# -------- LLM Serving (Local) -------- -self.llm_serving = LocalModelLLMServing_vllm( - hf_model_name_or_path="{local_model_path}", - vllm_tensor_parallel_size=1, - vllm_max_tokens=8192, - hf_local_dir="local", - model_name="{model_name}", -)''' - else: - llm_block = f'''# -------- LLM Serving (Remote) -------- -self.llm_serving = APILLMServing_request( - api_url="{chat_api_url}chat/completions", - key_name_of_api_key="DF_API_KEY", - model_name="{model_name}", - max_workers=100, -)''' - - # 6) 模板 - template = '''""" -Auto-generated Pipeline (supports init + run params) -""" -from dataflow.pipeline import PipelineABC -from dataflow.utils.storage import FileStorage -from 
dataflow.serving import APILLMServing_request, LocalModelLLMServing_vllm - -{import_section} - -class RecommendPipeline(PipelineABC): - def __init__(self): - super().__init__() - # -------- FileStorage -------- - self.storage = FileStorage( - first_entry_file_name="{file_path}", - cache_path="{cache_dir}", - file_name_prefix="dataflow_cache_step", - cache_type="{cache_type}", - ) - {llm_block} - - # -------- Operators -------- - {ops_init_block} - - def forward(self): - {forward_block} - -if __name__ == "__main__": - pipeline = RecommendPipeline() - pipeline.compile() - pipeline.forward() -''' - - code = template.format( - file_path=file_path, - import_section=import_section, - cache_dir=cache_dir, - cache_type=cache_type, - llm_block=llm_block, - ops_init_block=ops_init_block, - forward_block=forward_block, - ) - - return code - - - -# =========================================================只渲染run函数,其余不管:================================================== -def render_operator_blocks_with_params( - opname_and_params: List[Dict[str, Any]], - op_classes: Dict[str, type], - state: DFState -) -> Tuple[str, str]: - """ - 渲染算子初始化和调用代码,支持自定义 run 函数参数 - - Args: - opname_and_params: 算子名称和参数列表,格式: [{"op_name": "xxx", "params": {...}}, ...] 
- op_classes: 算子类字典 - state: DFState 对象 - - Returns: - (初始化代码块, forward调用代码块) - """ - init_lines: List[str] = [] - forward_lines: List[str] = [] - # 记录同一算子类出现的次数,避免 self.xxx 被后面的覆盖 - name_count: Dict[str, int] = {} - - for item in opname_and_params: - name = item["op_name"] - custom_params = item.get("params", {}) # 获取自定义的 run 参数 - - cls = op_classes[name] - base_var_name = snake_case(cls.__name__) - count = name_count.get(base_var_name, 0) + 1 - name_count[base_var_name] = count - if count > 1: - var_name = f"{base_var_name}_{count}" - else: - var_name = base_var_name - - init_kwargs, run_kwargs, run_has_storage = extract_op_params(cls) - - # -------- 渲染 __init__ 参数 -------- - rendered_init_args: List[str] = [] - for k, v in init_kwargs: - if k == "llm_serving": - rendered_init_args.append(f"{k}=self.llm_serving") - elif k == "prompt_template": - p_t = choose_prompt_template_by_llm(name, state) - rendered_init_args.append(f'{k}={p_t}') - else: - rendered_init_args.append(f"{k}={v}") - - init_line = f"self.{var_name} = {cls.__name__}(" + ", ".join(rendered_init_args) + ")" - init_lines.append(init_line) - - # -------- 渲染 run() 调用参数 -------- - run_args: List[str] = [] - - # 第一个参数 storage 保持不变 - if run_has_storage: - run_args.append("storage=self.storage.step()") - - # 处理其他参数:优先使用自定义参数,否则使用默认值 - for k, default_v in run_kwargs: - if k in custom_params: - # 使用自定义参数值,需要正确格式化 - actual_value = custom_params[k] - if isinstance(actual_value, str): - formatted_value = repr(actual_value) - elif isinstance(actual_value, (int, float, bool, type(None))): - formatted_value = repr(actual_value) - elif isinstance(actual_value, (list, dict)): - formatted_value = repr(actual_value) - else: - # 其他类型尝试转字符串 - formatted_value = repr(actual_value) - run_args.append(f"{k}={formatted_value}") - else: - # 使用默认值 - run_args.append(f"{k}={default_v}") - - # 构建完整的 run 调用 - if run_args: - call = ( - f"self.{var_name}.run(\n" - f" " + ",\n ".join(run_args) + "\n" - f")" - ) - else: - call = 
f"self.{var_name}.run()" - forward_lines.append(call) - - return "\n".join(init_lines), "\n".join(forward_lines) - - -def build_pipeline_code_with_run_params( - opname_and_params: List[Dict[str, Any]], - state: DFState, - *, - cache_dir: str = "./cache_local", - llm_local: bool = False, - local_model_path: str = "", - chat_api_url: str = "", - model_name: str = "deepseek-v3.2", - file_path: str = "", - prompted_generator_prompts: Optional[Dict[int, str]] = None -) -> str: - """ - 构建 pipeline 代码,支持为每个算子指定 run 函数的实际参数 - - 注意: - - render_operator_blocks_with_full_params 返回的代码已经包含了内部缩进 - (使用 "\n ".join()),因此不需要再次使用 indent_block - - Args: - opname_and_params: 算子名称和参数列表 - 格式: [ - {"op_name": "OperatorA", "params": {"param1": "value1", "param2": 123}}, - {"op_name": "OperatorB", "params": {"param_x": True}}, - ... - ] - 其中 params 是该算子 run 函数的参数(不包括 storage) - state: DFState 对象 - cache_dir: 缓存目录 - llm_local: 是否使用本地 LLM - local_model_path: 本地模型路径 - chat_api_url: API URL - model_name: 模型名称 - file_path: 输入文件路径 - prompted_generator_prompts: 预生成的 PromptedGenerator system_prompt 映射 - - Returns: - 生成的 pipeline 代码字符串 - """ - # 1) 提取所有算子名称 - op_names = [item["op_name"] for item in opname_and_params] - - # 2) 根据 file_path 后缀判断 cache_type - file_suffix = Path(file_path).suffix.lower() if file_path else "" - if file_suffix == ".jsonl": - cache_type = "jsonl" - elif file_suffix == ".json": - cache_type = "json" - elif file_suffix == ".csv": - cache_type = "csv" - else: - cache_type = "jsonl" - log.warning(f"[pipeline_assembler] Unknown file suffix '{file_suffix}', defaulting to 'jsonl'") - - # 3) 收集导入与类 - import_lines, stub_blocks, op_classes = group_imports(op_names) - - # 4) 渲染 operator 代码片段 - # render_operator_blocks_with_full_params 返回的代码已经包含了正确的缩进 - ops_init_block, forward_block = render_operator_blocks_with_full_params( - opname_and_params, - op_classes, - prompted_generator_prompts=prompted_generator_prompts - ) - - import_lines.extend(sorted(EXTRA_IMPORTS)) - - import_section = 
"\n".join(import_lines) - stub_section = "\n\n".join(stub_blocks) - - # 5) LLM-Serving 片段(无缩进,统一在模板中缩进) - if llm_local: - llm_block_raw = f""" -# -------- LLM Serving (Local) -------- -self.llm_serving = LocalModelLLMServing_vllm( - hf_model_name_or_path="{local_model_path}", - vllm_tensor_parallel_size=1, - vllm_max_tokens=8192, - hf_local_dir="local", - model_name="{model_name}", -) -""" - else: - llm_block_raw = f""" -# -------- LLM Serving (Remote) -------- -self.llm_serving = APILLMServing_request( - api_url="{chat_api_url}chat/completions", - key_name_of_api_key="DF_API_KEY", - model_name="{model_name}", - max_workers=100, -) -""" - - # 6) 统一缩进 llm_block - llm_block = indent_block(llm_block_raw, 8) - - # 7) 模板 - template = '''""" -Auto-generated by pipeline_assembler (with custom run params) -""" -from dataflow.pipeline import PipelineABC -from dataflow.utils.storage import FileStorage -from dataflow.serving import APILLMServing_request, LocalModelLLMServing_vllm - -{import_section} - -{stub_section} - -class RecommendPipeline(PipelineABC): - def __init__(self): - super().__init__() - # -------- FileStorage -------- - self.storage = FileStorage( - first_entry_file_name="{file_path}", - cache_path="{cache_dir}", - file_name_prefix="dataflow_cache_step", - cache_type="{cache_type}", - ) -{llm_block} - # -------- Operators -------- - {ops_init_block} - - def forward(self): - {forward_block} - -if __name__ == "__main__": - pipeline = RecommendPipeline() - pipeline.compile() - pipeline.forward() -''' - - # 8) 格式化并返回 - code = template.format( - file_path=file_path, - import_section=import_section, - stub_section=stub_section, - cache_dir=cache_dir, - cache_type=cache_type, - llm_block=llm_block, - ops_init_block=ops_init_block, - forward_block=forward_block, - ) - return code - - -def pipeline_assembler_with_params( - opname_and_params: List[Dict[str, Any]], - state: DFState, - **kwargs -) -> Dict[str, Any]: - """ - Pipeline 组装器(支持自定义 run 参数版本) - - Args: - 
opname_and_params: 算子名称和参数列表 - 格式: [{"op_name": "xxx", "params": {...}}, ...] - state: DFState 对象 - **kwargs: 其他参数传递给 build_pipeline_code_with_run_params - - Returns: - 包含生成代码的字典 {"pipe_code": ...} - """ - code = build_pipeline_code_with_run_params(opname_and_params, state, **kwargs) - return {"pipe_code": code} - - -async def apipeline_assembler_with_params( - opname_and_params: List[Dict[str, Any]], - state: DFState, - **kwargs -) -> Dict[str, Any]: - """异步版本""" - return pipeline_assembler_with_params(opname_and_params, state, **kwargs) - - -# ================================之前的版本 -def build_pipeline_code( - op_names: List[str], - state: DFState, - *, - cache_dir: str = "./cache_local", - llm_local: bool = False, - local_model_path: str = "", - chat_api_url: str = "", - model_name: str = "deepseek-v3.2", - file_path: str = "", - prompted_generator_prompts: Optional[Dict[int, str]] = None, # ← 添加这个参数 -) -> str: - # 1) 根据 file_path 后缀判断 cache_type - file_suffix = Path(file_path).suffix.lower() if file_path else "" - if file_suffix == ".jsonl": - cache_type = "jsonl" - elif file_suffix == ".json": - cache_type = "json" - elif file_suffix == ".csv": - cache_type = "csv" - else: - cache_type = "jsonl" - log.warning(f"[pipeline_assembler] Unknown file suffix '{file_suffix}', defaulting to 'jsonl'") - - # 2) 收集导入与类 - import_lines, stub_blocks, op_classes = group_imports(op_names) - - # 3) 渲染 operator 代码片段(传递 prompted_generator_prompts) - ops_init_block_raw, forward_block_raw = render_operator_blocks( - op_names, - op_classes, - state, - prompted_generator_prompts=prompted_generator_prompts - ) - - import_lines.extend(sorted(EXTRA_IMPORTS)) - - import_section = "\n".join(import_lines) - stub_section = "\n\n".join(stub_blocks) - - # 4) LLM-Serving 片段(无缩进,统一在模板中缩进) - if llm_local: - llm_block_raw = f""" -# -------- LLM Serving (Local) -------- -self.llm_serving = LocalModelLLMServing_vllm( - hf_model_name_or_path="{local_model_path}", - vllm_tensor_parallel_size=1, - 
vllm_max_tokens=8192, - hf_local_dir="local", - model_name="{model_name}", -) -""" - else: - llm_block_raw = f""" -# -------- LLM Serving (Remote) -------- -self.llm_serving = APILLMServing_request( - api_url="{chat_api_url}chat/completions", - key_name_of_api_key="DF_API_KEY", - model_name="{model_name}", - max_workers=100, -) -""" - - # 5) 统一缩进 - llm_block = indent_block(llm_block_raw, 8) - ops_init_block = indent_block(ops_init_block_raw, 8) - forward_block = indent_block(forward_block_raw, 8) - - # 6) 模板 - template = '''""" -Auto-generated by pipeline_assembler -""" -from dataflow.pipeline import PipelineABC -from dataflow.utils.storage import FileStorage -from dataflow.serving import APILLMServing_request, LocalModelLLMServing_vllm - -{import_section} - -{stub_section} - -class RecommendPipeline(PipelineABC): - def __init__(self): - super().__init__() - # -------- FileStorage -------- - self.storage = FileStorage( - first_entry_file_name="{file_path}", - cache_path="{cache_dir}", - file_name_prefix="dataflow_cache_step", - cache_type="{cache_type}", - ) -{llm_block} - -{ops_init_block} - - def forward(self): -{forward_block} - -if __name__ == "__main__": - pipeline = RecommendPipeline() - pipeline.compile() - pipeline.forward() -''' - - # 7) 格式化并返回 - code = template.format( - file_path=file_path, - import_section=import_section, - stub_section=stub_section, - cache_dir=cache_dir, - cache_type=cache_type, - llm_block=llm_block, - ops_init_block=ops_init_block, - forward_block=forward_block, - ) - return code - - -def pipeline_assembler(recommendation: List[str], state: DFState,**kwargs) -> Dict[str, Any]: - code = build_pipeline_code(recommendation, state, **kwargs) - return {"pipe_code": code} - - -async def apipeline_assembler(recommendation: List[str], **kwargs) -> Dict[str, Any]: - return pipeline_assembler(recommendation, **kwargs) - -# ===================================================================通过my pipline的 py文件,拿到结构化的输出信息 -""" -Parse a generated 
PipelineABC python file and export a graph schema:: - - { - "nodes": [...], - "edges": [...] - } - -Requirements: - - 支持 input_key / output_key 既可以是关键字参数也可以是位置参数 - - 允许同一个算子 run 多次 - - nodes.id 直接使用 self.xxx 的变量名 -""" -from collections import defaultdict -from dataflow.utils.registry import OPERATOR_REGISTRY - -# ----------------------------------------------------- # -# config & helpers -# ----------------------------------------------------- # -SKIP_CLASSES: set[str] = { - "FileStorage", - "APILLMServing_request", - "LocalModelLLMServing_vllm", -} - -_IN_PREFIXES = ("input", "input_") -_OUT_PREFIXES = ("output", "output_") - - -def _is_input(name: str) -> bool: - return name.startswith(_IN_PREFIXES) - - -def _is_output(name: str) -> bool: - return name.startswith(_OUT_PREFIXES) - - -def _guess_type(cls_obj: type | None, cls_name: str) -> str: - """ - Guess operator category for front-end icon & color. - 规则: - 1. package 名倒数第二段 (operators.xxx.{filter|parser}.xxx) - 2. 类名后缀启发 - 3. 兜底 'other' - """ - # rule-1 - if cls_obj is not None: - parts = cls_obj.__module__.split(".") - if len(parts) >= 2: - candidate = parts[-2] - if candidate not in {"__init__", "__main__"}: - return candidate - # rule-2 - lower = cls_name.lower() - for suf, cat in [ - ("parser", "parser"), - ("generator", "generate"), - ("filter", "filter"), - ("evaluator", "eval"), - ("refiner", "refine"), - ]: - if lower.endswith(suf): - return cat - # rule-3 - return "other" - - -def _literal_eval_safe(node: ast.AST) -> Any: - """ast.literal_eval 的宽松版本,失败就返回反编译字符串""" - if isinstance(node, ast.Constant): # fast path - return node.value - try: - return ast.literal_eval(node) - except Exception: - return ast.unparse(node) if hasattr(ast, "unparse") else repr(node) - - -# ----------------------------------------------------- # -# AST 解析主流程 -# ----------------------------------------------------- # -def parse_pipeline_file(file_path: str | Path) -> Dict[str, Any]: - """ - Parameters - ---------- - file_path : 
str | Path - 生成的 pipeline python 文件路径 - - Returns - ------- - dict - {"nodes": [...], "edges": [...]} - """ - file_path = Path(file_path) - src = file_path.read_text(encoding="utf-8") - tree = ast.parse(src, filename=str(file_path)) - - # ------------------------------------------------- # - # 1. 解析 __init__ 里的 operator 实例 - # ------------------------------------------------- # - def _parse_init(init_func: ast.FunctionDef) -> Dict[str, Tuple[str, Dict[str, Any]]]: - """ - Returns - ------- - var_name -> (cls_name, init_kwargs) - """ - results: Dict[str, Tuple[str, Dict[str, Any]]] = {} - for stmt in init_func.body: - if ( - isinstance(stmt, ast.Assign) - and stmt.targets - and isinstance(stmt.targets[0], ast.Attribute) - and isinstance(stmt.value, ast.Call) - ): - attr: ast.Attribute = stmt.targets[0] - if not (isinstance(attr.value, ast.Name) and attr.value.id == "self"): - continue - var_name = attr.attr - - call: ast.Call = stmt.value - # 取类名 - if isinstance(call.func, ast.Name): - cls_name = call.func.id - elif isinstance(call.func, ast.Attribute): - cls_name = call.func.attr - else: - continue - - if cls_name in SKIP_CLASSES: # 跳过非算子 - continue - - kwargs = { - kw.arg: _literal_eval_safe(kw.value) - for kw in call.keywords - if kw.arg is not None - } - results[var_name] = (cls_name, kwargs) - return results - - # ------------------------------------------------- # - # 2. 解析 forward() 里的 run 调用 - # ------------------------------------------------- # - def _parse_forward( - forward_func: ast.FunctionDef, - ) -> DefaultDict[str, List[Dict[str, Any]]]: - """ - Returns - ------- - var_name -> [run_kwargs ...] 
(保持出现顺序) - """ - mapping: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list) - - # walk 按源码顺序遍历需借助 ast.iter_child_nodes + 递归 - def _visit(node: ast.AST): - # 按出现顺序遍历 - for child in ast.iter_child_nodes(node): - if ( - isinstance(child, ast.Call) - and isinstance(child.func, ast.Attribute) - and child.func.attr == "run" - ): - obj = child.func.value - if ( - isinstance(obj, ast.Attribute) - and isinstance(obj.value, ast.Name) - and obj.value.id == "self" - ): - var_name = obj.attr - - # ------- 关键字参数 ------- - kw_dict = { - kw.arg: _literal_eval_safe(kw.value) - for kw in child.keywords - if kw.arg is not None - } - - # ------- 位置参数 ------- - # 假设位置顺序为 (storage, input_key, output_key, ...) - if len(child.args) >= 2: - kw_dict.setdefault("input_key", _literal_eval_safe(child.args[1])) - if len(child.args) >= 3: - kw_dict.setdefault("output_key", _literal_eval_safe(child.args[2])) - - mapping[var_name].append(kw_dict) - _visit(child) - - _visit(forward_func) - return mapping - - # ------------------------------------------------- # - # 3. 主 visitor:定位唯一继承 PipelineABC 的类 - # ------------------------------------------------- # - init_ops, forward_calls = {}, defaultdict(list) - - class PipelineVisitor(ast.NodeVisitor): - def visit_ClassDef(self, node: ast.ClassDef): # noqa: N802 - nonlocal init_ops, forward_calls - # naive 判断: 存在 forward() 方法即认为是 pipeline - has_forward = any( - isinstance(b, ast.FunctionDef) and b.name == "forward" for b in node.body - ) - if not has_forward: - return - for item in node.body: - if isinstance(item, ast.FunctionDef): - if item.name == "__init__": - init_ops = _parse_init(item) - elif item.name == "forward": - forward_calls = _parse_forward(item) - - PipelineVisitor().visit(tree) - - # ------------------------------------------------- # - # 4. 
build nodes - # ------------------------------------------------- # - def build_nodes() -> tuple[list[dict[str, Any]], - dict[str, str], - dict[str, tuple[str, str]]]: - """ - Returns - ------- - nodes : list of node-dict - var2id : var_name -> node_id (供后续查表) - produced_ports : label(str) -> (node_id, port_name) - """ - nodes: list[dict[str, Any]] = [] - var2id: dict[str, str] = {} - produced_ports: dict[str, tuple[str, str]] = {} - - global_counter = itertools.count(1) - - for var, (cls_name, init_kwargs) in init_ops.items(): - # -------- 生成 node_id -------- # - node_id = f"node{next(global_counter)}" # <-- 变成 node1/node2/… - - var2id[var] = node_id - - # forward() 第一次 run 的配置 - first_run_cfg = forward_calls.get(var, [{}])[0] - - # 把首次 run 产生的 output 标记为 “已经产生” - for k, v in first_run_cfg.items(): - if _is_output(k) and isinstance(v, str): - produced_ports[v] = (node_id, k) - try: - cls_obj = OPERATOR_REGISTRY.get(cls_name) - except Exception: - cls_obj = None - - nodes.append( - { - "id": node_id, - "name": cls_name, - "type": _guess_type(cls_obj, cls_name), - "config": { - "init": init_kwargs, - "run": first_run_cfg, - }, - } - ) - return nodes, var2id, produced_ports - - # ------------------------------------------------- # - # 5. 
build edges (按 forward 执行顺序) - # ------------------------------------------------- # - def build_edges( - produced_ports: dict[str, tuple[str, str]], - var2id: dict[str, str], - ) -> list[dict[str, Any]]: - edges: list[dict[str, Any]] = [] - for var, runs in forward_calls.items(): - tgt_id = var2id.get(var) - if not tgt_id: - continue - for run_cfg in runs: - for k, v in run_cfg.items(): - if _is_input(k) and isinstance(v, str) and v in produced_ports: - src_id, src_port = produced_ports[v] - edges.append( - { - "source": src_id, - "target": tgt_id, - "source_port": src_port, - "target_port": k, - } - ) - return edges - - nodes, var2id, produced_ports = build_nodes() - edges = build_edges(produced_ports, var2id) - return {"nodes": nodes, "edges": edges} - -def build_edges_from_nodes( - nodes: List[Dict[str, Any]] | Dict[str, Any], - save_path: str | Path -) -> Dict[str, Any]: - """ - 根据 nodes 自动生成 edges,并保存完整的 pipeline graph 到 save_path - - Args: - nodes: 节点信息,支持两种格式: - - List[Dict]: 直接的节点列表 - - Dict: 包含 "nodes" 键的字典,如 {"nodes": [...]} - save_path: 输出 json 文件路径 - - Returns: - 完整的 graph dict: {"nodes": ..., "edges": ...} - - Raises: - ValueError: 当输入格式不正确时 - """ - - # 统一处理输入格式,提取节点列表 - if isinstance(nodes, dict): - if "nodes" in nodes: - nodes_list = nodes["nodes"] - else: - raise ValueError("当 nodes 为字典时,必须包含 'nodes' 键") - elif isinstance(nodes, list): - nodes_list = nodes - else: - raise ValueError(f"nodes 必须是 list 或 dict 类型,当前类型: {type(nodes)}") - - # 验证节点列表不为空 - if not nodes_list: - log.warning("[build_edges_from_nodes] 节点列表为空") - graph = {"nodes": [], "edges": []} - save_path = Path(save_path) - save_path.write_text(json.dumps(graph, indent=2, ensure_ascii=False), encoding="utf-8") - return graph - - # 1. 
收集所有 output_key 到 (node_id, port_name) 的映射 - produced_outputs = {} # output_key_value -> (node_id, port_name) - for node in nodes_list: - if not isinstance(node, dict) or "id" not in node: - log.warning(f"[build_edges_from_nodes] 跳过无效节点: {node}") - continue - - run_cfg = node.get("config", {}).get("run", {}) - for key, value in run_cfg.items(): - if isinstance(key, str) and key.startswith("output") and isinstance(value, str): - produced_outputs[value] = (node["id"], key) - - # 2. 遍历节点,查找 input_key 引用的 output_key,生成边 - edges = [] - for node in nodes_list: - if not isinstance(node, dict) or "id" not in node: - continue - - run_cfg = node.get("config", {}).get("run", {}) - for key, value in run_cfg.items(): - if isinstance(key, str) and key.startswith("input") and isinstance(value, str): - if value in produced_outputs: - src_id, src_port = produced_outputs[value] - edges.append({ - "source": src_id, - "target": node["id"], - "source_port": src_port, - "target_port": key - }) - - # 3. 保存并返回 - graph = {"nodes": nodes_list, "edges": edges} - save_path = Path(save_path) - save_path.write_text(json.dumps(graph, indent=2, ensure_ascii=False), encoding="utf-8") - - log.info(f"[build_edges_from_nodes] 生成 {len(nodes_list)} 个节点,{len(edges)} 条边") - return graph - -# ----------------------------------------------------- # -# CLI 方便快速测试(免参数版) -# ----------------------------------------------------- # -if __name__ == "__main__": - import json - from pathlib import Path - import pprint - - PY_PATH = Path("") - - graph = parse_pipeline_file(PY_PATH) - - pprint.pprint(graph, width=120) - - OUT_PATH = PY_PATH.with_suffix(".json") - OUT_PATH.write_text(json.dumps(graph, indent=2, ensure_ascii=False), encoding="utf-8") - print(f"saved to {OUT_PATH}") - - - - - - - - - - - - - - - - - - - - - -# if __name__ == "__main__": - # test_ops = [ - # "SQLGenerator", - # "SQLExecutionFilter", - # "SQLComponentClassifier", - # ] - # result = pipeline_assembler( - # test_ops, - # 
cache_dir="./cache_local", - # llm_local=False, - # chat_api_url="", - # model_name="gpt-4o", - # file_path = " " - # ) - # code_str = result["pipe_code"] - # write_pipeline_file(code_str, file_name="my_recommend_pipeline.py", overwrite=True) - # print("Generated pipeline code written to my_recommend_pipeline.py") diff --git a/dataflow_agent/toolkits/ragtool/__init__.py b/dataflow_agent/toolkits/ragtool/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/dataflow_agent/toolkits/shape_libraries/aws4.md b/dataflow_agent/toolkits/shape_libraries/aws4.md deleted file mode 100644 index fe4b248..0000000 --- a/dataflow_agent/toolkits/shape_libraries/aws4.md +++ /dev/null @@ -1,1049 +0,0 @@ -# aws4 - -**Type:** mxgraph shapes -**Prefix:** `mxgraph.aws4` - -## Usage - -```xml - - - -``` - -For simple shapes use: `shape=mxgraph.aws4.{shape};fillColor=#232F3D;` - -## Shapes (1032) - -- `a1_instance` -- `access_analyzer` -- `action` -- `activate` -- `actuator` -- `ad_connector` -- `addon` -- `agent` -- `agent2` -- `alarm` -- `alert` -- `alexa_enabled_device` -- `alexa_for_business` -- `alexa_skill` -- `alexa_smart_home_skill` -- `alexa_voice_service` -- `all_products` -- `ami` -- `amplify` -- `amplify_aws_amplify_studio` -- `analytics` -- `apache_mxnet_on_aws` -- `api_gateway` -- `app_config` -- `app_mesh` -- `app_runner` -- `app_studio` -- `app_wizard` -- `appfabric` -- `appflow` -- `application` -- `application_auto_scaling` -- `application_composer` -- `application_cost_profiler` -- `application_discovery_service` -- `application_discovery_service_aws_agentless_collector` -- `application_discovery_service_aws_discovery_agent` -- `application_discovery_service_migration_evaluator_collector` -- `application_integration` -- `application_load_balancer` -- `application_recovery_controller` -- `apps` -- `appstream_20` -- `appsync` -- `ar_vr` -- `archive` -- `artifact` -- `athena` -- `athena_data_source_connectors` -- `attribute` -- `attributes` -- 
`audit_manager` -- `augmented_ai` -- `aurora` -- `aurora_instance` -- `aurora_instance_alt` -- `authenticated_user` -- `auto_scaling` -- `auto_scaling2` -- `auto_scaling3` -- `automation` -- `autoscaling` -- `aws_backup_for_aws_cloudformation` -- `aws_backup_legal_hold` -- `aws_backup_support_for_amazon_fsx_for_netapp_ontap` -- `aws_backup_vault_lock` -- `aws_backup_virtual_machine_monitor` -- `aws_cloud` -- `aws_glue_data_quality` -- `aws_glue_for_ray` -- `aws_user_notifications` -- `b2b_data_interchange` -- `backint_agent` -- `backup` -- `backup_audit_manager` -- `backup_aws_backup_support_for_amazon_s3` -- `backup_aws_backup_support_for_vmware_workloads` -- `backup_backup_plan` -- `backup_backup_restore` -- `backup_compliance_reporting` -- `backup_compute` -- `backup_database` -- `backup_gateway` -- `backup_plan` -- `backup_recovery_point_objective` -- `backup_recovery_time_objective` -- `backup_restore` -- `backup_storage` -- `backup_vault` -- `backup_virtual_machine` -- `backup_virtual_machine_monitor` -- `bank` -- `batch` -- `bedrock` -- `blockchain` -- `blockchain_resource` -- `bottlerocket` -- `braket` -- `braket_chandelier` -- `braket_chip` -- `braket_embedded_simulator` -- `braket_managed_simulator` -- `braket_noise_simulator` -- `braket_qpu` -- `braket_simulator` -- `braket_simulator_1` -- `braket_simulator_2` -- `braket_simulator_3` -- `braket_simulator_4` -- `braket_state_vector` -- `braket_tensor_network` -- `bucket` -- `bucket_with_objects` -- `budgets` -- `budgets_2` -- `business_application` -- `bycicle` -- `c4_instance` -- `c5_instance` -- `c5a` -- `c5ad` -- `c5d` -- `c5n_instance` -- `c6g_instance` -- `c6gd` -- `cache_node` -- `cached_volume` -- `camera` -- `camera2` -- `car` -- `cart` -- `certificate_manager` -- `certificate_manager_2` -- `certificate_manager_3` -- `change_set` -- `chat` -- `chatbot` -- `checklist` -- `checklist_cost` -- `checklist_fault_tolerant` -- `checklist_performance` -- `checklist_security` -- `chime` -- `chime_sdk` -- 
`classic_load_balancer` -- `clean_rooms` -- `client` -- `client_vpn` -- `cloud9` -- `cloud_control_api` -- `cloud_development_kit` -- `cloud_digital_interface` -- `cloud_directory` -- `cloud_extension_ros` -- `cloud_map` -- `cloud_map_resource` -- `cloud_wan` -- `cloud_wan_segment_network` -- `cloud_wan_transit_gateway_route_table_attachment` -- `cloud_wan_virtual_pop` -- `cloudendure_disaster_recovery` -- `cloudendure_migration` -- `cloudformation` -- `cloudfront` -- `cloudfront_functions` -- `cloudhsm` -- `cloudsearch` -- `cloudsearch2` -- `cloudshell` -- `cloudtrail` -- `cloudtrail_cloudtrail_lake` -- `cloudwatch` -- `cloudwatch_2` -- `cloudwatch_cross_account_observability` -- `cloudwatch_data_protection` -- `cloudwatch_evidently` -- `cloudwatch_logs` -- `cloudwatch_metrics_insights` -- `cloudwatch_rum` -- `cloudwatch_synthetics` -- `cluster` -- `codeartifact` -- `codebuild` -- `codecatalyst` -- `codecommit` -- `codedeploy` -- `codeguru` -- `codeguru_2` -- `codepipeline` -- `codestar` -- `codewhisperer` -- `coffee_pot` -- `cognito` -- `cold_storage` -- `command_line_interface` -- `comprehend` -- `comprehend_medical` -- `compute` -- `compute_optimizer` -- `config` -- `connect` -- `connector` -- `contact_center` -- `container_1` -- `container_2` -- `container_3` -- `container_registry_image` -- `containers` -- `control_tower` -- `corporate_data_center` -- `corporate_data_center2` -- `corretto` -- `cost_and_usage_report` -- `cost_explorer` -- `cost_management` -- `credentials` -- `custom_billing_manager` -- `custom_event_bus_resource` -- `customer_enablement` -- `customer_engagement` -- `customer_gateway` -- `d2_instance` -- `d3_instance` -- `d3en_instance` -- `data_encryption_key` -- `data_exchange` -- `data_exchange_for_apis` -- `data_lake_resource_icon` -- `data_pipeline` -- `data_set` -- `data_stream` -- `data_table` -- `data_transfer_terminal` -- `database` -- `database_migration_service` -- `database_migration_workflow_job` -- `datasync` -- 
`datasync_discovery` -- `datazone` -- `datazone_business_data_catalog` -- `datazone_data_portal` -- `datazone_data_projects` -- `db_instance` -- `db_instance_read_replica` -- `db_instance_standby` -- `db_on_instance` -- `db_on_instance2` -- `deadline_cloud` -- `deep_learning_amis` -- `deep_learning_containers` -- `deepcomposer` -- `deeplens` -- `deepracer` -- `default_event_bus_resource` -- `dense_compute_node` -- `dense_storage_node` -- `deployment` -- `deployments` -- `desired_state` -- `desktop_and_app_streaming` -- `detective` -- `developer_tools` -- `development_environment` -- `device_farm` -- `devops_guru` -- `devops_guru_insights` -- `direct_connect` -- `directory_service` -- `disk` -- `distro_for_opentelemetry` -- `document` -- `documentdb_elastic_clusters` -- `documentdb_with_mongodb_compatibility` -- `documents` -- `documents2` -- `documents3` -- `door_lock` -- `download_distribution` -- `dynamodb` -- `dynamodb_dax` -- `dynamodb_standard_access_table_class` -- `dynamodb_standard_infrequent_access_table_class` -- `dynamodb_stream` -- `ec2` -- `ec2_aws_microservice_extractor_for_net` -- `ec2_c6a_instance` -- `ec2_c6gn_instance` -- `ec2_c6i_instance` -- `ec2_c6in_instance` -- `ec2_c7g_instance` -- `ec2_c7gn_instance` -- `ec2_dl1_instance` -- `ec2_g5_instance` -- `ec2_g5g_instance` -- `ec2_hpc6a_instance` -- `ec2_hpc6id_instance` -- `ec2_i4i_instance` -- `ec2_im4gn_instance` -- `ec2_image_builder` -- `ec2_inf2_instance` -- `ec2_instance_contents` -- `ec2_is4gen_instance` -- `ec2_m1_mac_instance` -- `ec2_m6a_instance` -- `ec2_m6i_instance` -- `ec2_m6idn_instance` -- `ec2_m6in_instance` -- `ec2_p4de_instance` -- `ec2_r6a_instance` -- `ec2_r6i_instance` -- `ec2_r6idn_instance` -- `ec2_r6in_instance` -- `ec2_r7iz_instance` -- `ec2_trn1_instance` -- `ec2_vt1_instance` -- `ec2_x2gd_instance` -- `ec2_x2idn_instance` -- `ec2_x2iedn_instance` -- `ec2_x2iezn_instance` -- `echo` -- `ecr` -- `ecs` -- `ecs_anywhere` -- `ecs_copilot_cli` -- `ecs_service` -- 
`ecs_service_connect` -- `ecs_task` -- `edge_location` -- `efs_infrequentaccess` -- `efs_standard` -- `eks` -- `eks_anywhere` -- `eks_cloud` -- `eks_distro` -- `eks_on_outposts` -- `elastic_beanstalk` -- `elastic_block_store` -- `elastic_block_store_amazon_data_lifecycle_manager` -- `elastic_block_store_volume_gp3` -- `elastic_fabric_adapter` -- `elastic_file_system` -- `elastic_file_system_elastic_throughput` -- `elastic_file_system_infrequent_access` -- `elastic_file_system_intelligent_tiering` -- `elastic_file_system_one_zone` -- `elastic_file_system_one_zone_infrequent_access` -- `elastic_file_system_one_zone_standard` -- `elastic_file_system_standard` -- `elastic_file_system_standard_infrequent_access` -- `elastic_inference` -- `elastic_inference_2` -- `elastic_ip_address` -- `elastic_load_balancing` -- `elastic_network_adapter` -- `elastic_network_interface` -- `elastic_transcoder` -- `elastic_vmware_service` -- `elasticache` -- `elasticache_for_memcached` -- `elasticache_for_redis` -- `elasticache_for_valkey` -- `elasticsearch_service` -- `elemental` -- `elemental_link` -- `elemental_mediaconnect` -- `elemental_mediaconvert` -- `elemental_medialive` -- `elemental_mediapackage` -- `elemental_mediastore` -- `elemental_mediatailor` -- `email` -- `email_2` -- `email_notification` -- `emr` -- `emr_engine` -- `emr_engine_mapr_m3` -- `emr_engine_mapr_m5` -- `emr_engine_mapr_m7` -- `encrypted_data` -- `end_user_messaging` -- `endpoint` -- `endpoints` -- `entity_resolution` -- `event` -- `event_event_based` -- `event_resource` -- `event_time_based` -- `eventbridge` -- `eventbridge_custom_event_bus_resource` -- `eventbridge_default_event_bus_resource` -- `eventbridge_pipes` -- `eventbridge_saas_partner_event_bus_resource` -- `eventbridge_scheduler` -- `eventbridge_schema` -- `eventbridge_schema_registry` -- `express_workflow` -- `external_sdk` -- `external_toolkit` -- `f1_instance` -- `factory` -- `fargate` -- `fault_injection_simulator` -- `file_cache` -- 
`file_cache_hybrid_nfs_linked_datasets` -- `file_cache_on_premises_nfs_linked_datasets` -- `file_cache_s3_linked_datasets` -- `file_gateway` -- `file_system` -- `filtering_rule` -- `finding` -- `finspace` -- `firetv` -- `firetv_stick` -- `firewall_manager` -- `fleet_management` -- `flow_logs` -- `folder` -- `folders` -- `forecast` -- `forums` -- `fraud_detector` -- `freertos` -- `fsx` -- `fsx_file_gateway` -- `fsx_for_lustre` -- `fsx_for_netapp_ontap` -- `fsx_for_openzfs` -- `fsx_for_windows_file_server` -- `g3_instance` -- `g4ad_instance` -- `g4dn` -- `game_tech` -- `game_tech2` -- `gamekit` -- `gamelift` -- `gamelift_2` -- `gamelift_streams` -- `games` -- `gamesparks` -- `gateway` -- `gateway_load_balancer` -- `gear` -- `general` -- `general_access_points` -- `generic` -- `generic_application` -- `generic_database` -- `generic_firewall` -- `genomics_cli` -- `git_repository` -- `glacier` -- `glacier_deep_archive` -- `global_accelerator` -- `global_secondary_index` -- `globe` -- `glue` -- `glue_crawlers` -- `glue_data_catalog` -- `glue_databrew` -- `glue_elastic_views` -- `greengrass` -- `ground_station` -- `group_account` -- `group_auto_scaling_group` -- `group_availability_zone` -- `group_aws_cloud` -- `group_aws_cloud_alt` -- `group_aws_step_functions_workflow` -- `group_corporate_data_center` -- `group_ec2_instance_contents` -- `group_elastic_beanstalk` -- `group_elastic_load_balancing` -- `group_iot_greengrass` -- `group_iot_greengrass_deployment` -- `group_on_premise` -- `group_region` -- `group_security_group` -- `group_spot_fleet` -- `group_subnet` -- `group_vpc` -- `group_vpc2` -- `guardduty` -- `h1_instance` -- `habana_gaudi` -- `hardware_board` -- `hdfs_cluster` -- `healthimaging` -- `healthlake` -- `healthscribe` -- `high_memory_instance` -- `honeycode` -- `hosted_zone` -- `house` -- `http2_protocol` -- `http_notification` -- `http_protocol` -- `i2` -- `i3_instance` -- `i3en` -- `identity_access_management_iam_roles_anywhere` -- 
`identity_and_access_management` -- `illustration_desktop` -- `illustration_devices` -- `illustration_notification` -- `illustration_office_building` -- `illustration_users` -- `import_export` -- `inf1` -- `inferentia` -- `infrequent_access_storage_class` -- `inspector` -- `instance` -- `instance2` -- `instance_with_cloudwatch` -- `instance_with_cloudwatch2` -- `instances` -- `instances_2` -- `intelligent_tiering` -- `interactive_video` -- `internet` -- `internet_alt1` -- `internet_alt2` -- `internet_alt22` -- `internet_gateway` -- `internet_of_things` -- `inventory` -- `iot_1click` -- `iot_analytics` -- `iot_analytics_channel` -- `iot_analytics_data_store` -- `iot_analytics_dataset` -- `iot_analytics_pipeline` -- `iot_button` -- `iot_core` -- `iot_core_device_advisor` -- `iot_core_device_location` -- `iot_device_defender` -- `iot_device_defender_iot_device_jobs` -- `iot_device_gateway` -- `iot_device_jobs_resource` -- `iot_device_management` -- `iot_device_management_fleet` -- `iot_device_tester` -- `iot_edukit` -- `iot_events` -- `iot_expresslink` -- `iot_fleetwise` -- `iot_greengrass_artifact` -- `iot_greengrass_component` -- `iot_greengrass_component_machine_learning` -- `iot_greengrass_component_nucleus` -- `iot_greengrass_component_private` -- `iot_greengrass_component_public` -- `iot_greengrass_interprocess_communication` -- `iot_greengrass_protocol` -- `iot_greengrass_recipe` -- `iot_greengrass_stream_manager` -- `iot_lorawan_protocol` -- `iot_over_the_air_update` -- `iot_roborunner` -- `iot_sailboat` -- `iot_sitewise` -- `iot_sitewise_asset` -- `iot_sitewise_asset_hierarchy` -- `iot_sitewise_asset_model` -- `iot_sitewise_asset_properties` -- `iot_sitewise_data_streams` -- `iot_thing_freertos_device` -- `iot_thing_humidity_sensor` -- `iot_thing_industrial_pc` -- `iot_thing_plc` -- `iot_thing_relay` -- `iot_thing_stacklight` -- `iot_thing_temperature_humidity_sensor` -- `iot_thing_temperature_sensor` -- `iot_thing_temperature_vibration_sensor` -- 
`iot_thing_vibration_sensor` -- `iot_things_graph` -- `iot_twinmaker` -- `iq` -- `item` -- `items` -- `json_script` -- `kendra` -- `key_management_service` -- `key_management_service_external_key_store` -- `keyspaces` -- `kinesis` -- `kinesis_data_analytics` -- `kinesis_data_firehose` -- `kinesis_data_streams` -- `kinesis_video_streams` -- `lake_formation` -- `lambda` -- `lambda_function` -- `layers` -- `lex` -- `license_manager` -- `license_manager_application_discovery` -- `license_manager_license_blending` -- `lightbulb` -- `lightsail` -- `lightsail_for_research` -- `local_zones` -- `location_service` -- `location_service_geofence` -- `location_service_map` -- `location_service_place` -- `location_service_routes` -- `location_service_track` -- `logs` -- `long_term_security_credential` -- `lookout_for_equipment` -- `lookout_for_metrics` -- `lookout_for_vision` -- `lumberyard` -- `m4_instance` -- `m5_instance` -- `m5a_instance` -- `m5d_instance` -- `m5dn_instance` -- `m5n` -- `m5n_instance` -- `m5zn_instance` -- `m6g_instance` -- `m6gd_instance` -- `mac_instance` -- `machine_learning` -- `macie` -- `magnifying_glass` -- `magnifying_glass_2` -- `mainframe_modernization` -- `mainframe_modernization_analyzer` -- `mainframe_modernization_compiler` -- `mainframe_modernization_converter` -- `mainframe_modernization_developer` -- `mainframe_modernization_runtime` -- `maintenance_windows` -- `managed_apache_cassandra_service` -- `managed_blockchain` -- `managed_ms_ad` -- `managed_service_for_apache_flink` -- `managed_service_for_grafana` -- `managed_service_for_prometheus` -- `managed_services` -- `managed_streaming_for_kafka` -- `managed_workflows_for_apache_airflow` -- `management_and_governance` -- `management_console` -- `management_console2` -- `marketplace` -- `media_services` -- `mediaconnect_gateway` -- `medical_emergency` -- `memorydb_for_redis` -- `mesh` -- `message` -- `metrics` -- `mfa_token` -- `migration_and_transfer` -- `migration_evaluator` -- 
`migration_hub` -- `migration_hub_refactor_spaces_applications` -- `migration_hub_refactor_spaces_environments` -- `migration_hub_refactor_spaces_services` -- `mobile` -- `mobile_application` -- `mobile_client` -- `mobile_hub` -- `monitoring` -- `monitron` -- `mq` -- `mq_broker` -- `mqtt_protocol` -- `ms_sql_instance` -- `ms_sql_instance_alternate` -- `msk_amazon_msk_connect` -- `multimedia` -- `multiple_volumes_resource` -- `mxgraph.aws4` -- `mysql_db_instance` -- `mysql_db_instance_alternate` -- `namespace` -- `nat_gateway` -- `neptune` -- `network_access_control_list` -- `network_firewall` -- `network_firewall_endpoints` -- `network_load_balancer` -- `networking_and_content_delivery` -- `neuron_ml_sdk` -- `nice_dcv` -- `nice_enginframe` -- `nimble_studio` -- `nitro_enclaves` -- `non_cached_volume` -- `notebook` -- `nova` -- `nova2` -- `object` -- `office_building` -- `omics` -- `one_zone_ia` -- `open_3d_engine` -- `open_3d_engine_2` -- `opensearch_dashboards` -- `opensearch_ingestion` -- `opensearch_observability` -- `opensearch_service_cluster_administrator_node` -- `opensearch_service_data_node` -- `opensearch_service_index` -- `opensearch_service_traces` -- `opensearch_service_ultrawarm_node` -- `opsworks` -- `opsworks_apps` -- `opsworks_permissions` -- `optimized_instance` -- `oracle_database_at_aws` -- `oracle_db_instance` -- `oracle_db_instance_alternate` -- `organizations` -- `organizations_account` -- `organizations_account2` -- `organizations_management_account` -- `organizations_management_account2` -- `organizations_organizational_unit` -- `organizations_organizational_unit2` -- `outposts` -- `outposts_1u_and_2u_servers` -- `outposts_family` -- `p2_instance` -- `p3_instance` -- `p3dn_instance` -- `p4_instance` -- `p4d_instance` -- `panorama` -- `parallel_cluster` -- `parallel_computing_service` -- `parameter_store` -- `patch_manager` -- `payment_cryptography` -- `peering` -- `permissions` -- `permissions_2` -- `personal_health_dashboard` -- 
`personalize` -- `pinpoint` -- `pinpoint_journey` -- `police_emergency` -- `policy` -- `polly` -- `postgresql_instance` -- `private_5g` -- `private_certificate_authority` -- `privatelink` -- `professional_services` -- `programming_language` -- `proton` -- `q` -- `quantum_ledger_database` -- `quantum_technologies` -- `question` -- `queue` -- `quicksight` -- `quicksight_paginated_reports` -- `r4_instance` -- `r5_instance` -- `r5a_instance` -- `r5ad_instance` -- `r5b_instance` -- `r5d_instance` -- `r5gd_instance` -- `r5n` -- `r5n_instance` -- `r6g_instance` -- `rdn_instance` -- `rds` -- `rds_blue_green_deployments` -- `rds_instance` -- `rds_instance_alt` -- `rds_mariadb_instance` -- `rds_mariadb_instance_alt` -- `rds_multi_az` -- `rds_multi_az_db_cluster` -- `rds_mysql_instance` -- `rds_mysql_instance_alt` -- `rds_on_vmware` -- `rds_optimized_writes` -- `rds_oracle_instance` -- `rds_oracle_instance_alt` -- `rds_piop` -- `rds_piops` -- `rds_postgresql_instance` -- `rds_postgresql_instance_alt` -- `rds_proxy` -- `rds_proxy_alt` -- `rds_sql_server_instance` -- `rds_sql_server_instance_alt` -- `rds_trusted_language_extensions_for_postgresql` -- `recover` -- `red_hat_openshift` -- `redshift` -- `redshift_auto_copy` -- `redshift_data_sharing_governance` -- `redshift_ml` -- `redshift_query_editor_v20_light` -- `redshift_ra3` -- `redshift_streaming_ingestion` -- `registry` -- `rekognition` -- `rekognition_2` -- `rekognition_image` -- `rekognition_video` -- `replication` -- `replication_time_control` -- `reported_state` -- `repost` -- `repost_private` -- `rescue` -- `reserved_instance_reporting` -- `resilience_hub` -- `resource` -- `resource_access_manager` -- `resource_explorer` -- `resources` -- `robomaker` -- `robotics` -- `role` -- `route_53` -- `route_53_application_recovery_controller` -- `route_53_readiness_checks` -- `route_53_resolver` -- `route_53_resolver_dns_firewall` -- `route_53_resolver_query_logging` -- `route_53_routing_controls` -- `route_table` -- `router` 
-- `rule` -- `rule_2` -- `rule_3` -- `run_command` -- `s3` -- `s3_batch_operations` -- `s3_express_one_zone` -- `s3_file_gateway` -- `s3_multi_region_access_points` -- `s3_object_lambda` -- `s3_object_lambda_access_points` -- `s3_object_lock` -- `s3_on_outposts` -- `s3_on_outposts_storage` -- `s3_replication_time_control` -- `s3_select` -- `s3_storage_lens` -- `s3_tables` -- `s3_vectors` -- `saas_event_bus_resource` -- `sagemaker` -- `sagemaker_2` -- `sagemaker_canvas` -- `sagemaker_geospatial_ml` -- `sagemaker_ground_truth` -- `sagemaker_model` -- `sagemaker_notebook` -- `sagemaker_shadow_testing` -- `sagemaker_studio_lab` -- `sagemaker_train` -- `saml_token` -- `satellite` -- `savings_plans` -- `search_documents` -- `secrets_manager` -- `security_group` -- `security_hub` -- `security_hub_finding` -- `security_identity_and_compliance` -- `security_incident_response` -- `security_lake` -- `sensor` -- `server_migration_service` -- `serverless` -- `serverless_application_repository` -- `servers` -- `service` -- `service_catalog` -- `service_management_connector` -- `servo` -- `shadow` -- `shield` -- `shield2` -- `shield_shield_advanced` -- `signer` -- `simple_ad` -- `simple_email_service` -- `simple_storage_service_directory_bucket` -- `simple_storage_service_s3_glacier_instant_retrieval` -- `simspace_weaver` -- `simulation` -- `simulator` -- `single_sign_on` -- `site_to_site_vpn` -- `snapshot` -- `snowball` -- `snowball_edge` -- `snowcone` -- `snowmobile` -- `sns` -- `source_code` -- `spot_instance` -- `sql_primary` -- `sql_replica` -- `sql_workbench` -- `sqs` -- `ssl_padlock` -- `stack` -- `stack2` -- `standard_ia` -- `state_manager` -- `step_functions` -- `storage` -- `storage_gateway` -- `streaming_distribution` -- `sts` -- `sts_alternate` -- `sumerian` -- `supply_chain` -- `support` -- `systems_manager` -- `systems_manager_application_manager` -- `systems_manager_change_calendar` -- `systems_manager_change_manager` -- `systems_manager_compliance` -- 
`systems_manager_distributor` -- `systems_manager_incident_manager` -- `systems_manager_opscenter` -- `systems_manager_session_manager` -- `t2_instance` -- `t3_instance` -- `t3a_instance` -- `t4g_instance` -- `table` -- `tape_gateway` -- `tape_storage` -- `telco_network_builder` -- `template` -- `temporary_security_credential` -- `tensorflow_on_aws` -- `textract` -- `textract_analyze_lending` -- `thermostat` -- `thinkbox_deadline` -- `thinkbox_draft` -- `thinkbox_frost` -- `thinkbox_krakatoa` -- `thinkbox_sequoia` -- `thinkbox_stoke` -- `thinkbox_xmesh` -- `timestream` -- `tools_and_sdks` -- `topic` -- `topic_2` -- `torchserve` -- `traditional_server` -- `training_certification` -- `trainium_instance` -- `transcribe` -- `transfer_family` -- `transfer_family_aws_as2` -- `transfer_for_ftp_resource` -- `transfer_for_ftps_resource` -- `transfer_for_sftp` -- `transfer_for_sftp_resource` -- `transform` -- `transit_gateway` -- `transit_gateway_attachment` -- `translate` -- `travel` -- `trusted_advisor` -- `user` -- `user_notifications` -- `users` -- `utility` -- `vault` -- `verified_access` -- `verified_permissions` -- `virtual_gateway` -- `virtual_node` -- `virtual_private_cloud` -- `virtual_router` -- `virtual_service` -- `virtual_tape_library` -- `vmware_cloud_on_aws` -- `volume` -- `volume_gateway` -- `vpc` -- `vpc_access_points` -- `vpc_carrier_gateway` -- `vpc_lattice` -- `vpc_network_access_analyzer` -- `vpc_privatelink` -- `vpc_reachability_analyzer` -- `vpc_traffic_mirroring` -- `vpc_virtual_private_cloud_vpc` -- `vpn_connection` -- `vpn_gateway` -- `waf` -- `waf_bad_bot` -- `waf_bot` -- `waf_bot_control` -- `waf_labels` -- `waf_managed_rule` -- `waf_rule` -- `wavelength` -- `well_architect_tool` -- `well_architected_tool` -- `wickr` -- `windfarm` -- `work_package` -- `workdocs` -- `worklink` -- `workmail` -- `workspaces` -- `workspaces_family` -- `workspaces_family_amazon_workspaces` -- `workspaces_family_amazon_workspaces_core` -- `workspaces_thin_client` -- 
`workspaces_workspaces_web` -- `x1_instance` -- `x1_instance2` -- `x1e_instance` -- `xray` -- `z1d_instance` diff --git a/dataflow_agent/toolkits/shape_libraries/azure2.md b/dataflow_agent/toolkits/shape_libraries/azure2.md deleted file mode 100644 index ce1d46c..0000000 --- a/dataflow_agent/toolkits/shape_libraries/azure2.md +++ /dev/null @@ -1,431 +0,0 @@ -# azure2 - -**Type:** SVG images -**Path:** `img/lib/azure2/` - -## Usage - -```xml - - - -``` - -## Shapes (648) - -Shapes are organized by category: `azure2/{category}/{shape}.svg` - -### ai_machine_learning (30) - -- `AI_Studio` -- `Anomaly_Detector` -- `Azure_Applied_AI` -- `Azure_Experimentation_Studio` -- `Azure_Object_Understanding` -- `Azure_OpenAI` -- `Batch_AI` -- `Bonsai` -- `Bot_Services` -- `Cognitive_Services` -- `Cognitive_Services_Decisions` -- `Computer_Vision` -- `Content_Moderators` -- `Content_Safety` -- `Custom_Vision` -- `Face_APIs` -- `Form_Recognizers` -- `Genomics` -- `Immersive_Readers` -- `Language_Services` -- `Language_Understanding` -- `Machine_Learning` -- `Machine_Learning_Studio_Classic_Web_Services` -- `Machine_Learning_Studio_Web_Service_Plans` -- `Machine_Learning_Studio_Workspaces` -- `Personalizers` -- `QnA_Makers` -- `Serverless_Search` -- `Speech_Services` -- `Translator_Text` - -### analytics (14) - -- `Analysis_Services` -- `Azure_Databricks` -- `Azure_Synapse_Analytics` -- `Azure_Workbooks` -- `Data_Lake_Analytics` -- `Data_Lake_Store_Gen1` -- `Endpoint_Analytics` -- `Event_Hub_Clusters` -- `Event_Hubs` -- `HD_Insight_Clusters` -- `Log_Analytics_Workspaces` -- `Power_BI_Embedded` -- `Power_Platform` -- `Stream_Analytics_Jobs` - -### app_services (9) - -- `API_Management_Services` -- `App_Service_Certificates` -- `App_Service_Domains` -- `App_Service_Environments` -- `App_Service_Plans` -- `App_Services` -- `CDN_Profiles` -- `Notification_Hubs` -- `Search_Services` - -### compute (38) - -- `App_Services` -- `Application_Group` -- `Automanaged_VM` -- `Availability_Sets` 
-- `Azure_Compute_Galleries` -- `Azure_Spring_Cloud` -- `Batch_Accounts` -- `Cloud_Services_Classic` -- `Container_Instances` -- `Container_Services_Deprecated` -- `Disk_Encryption_Sets` -- `Disks` -- `Disks_Classic` -- `Disks_Snapshots` -- `Function_Apps` -- `Host_Groups` -- `Host_Pools` -- `Hosts` -- `Image_Definitions` -- `Image_Templates` -- `Image_Versions` -- `Images` -- `Kubernetes_Services` -- `Maintenance_Configuration` -- `Managed_Service_Fabric` -- `Mesh_Applications` -- `Metrics_Advisor` -- `OS_Images_Classic` -- `Restore_Points` -- `Restore_Points_Collections` -- `Service_Fabric_Clusters` -- `Shared_Image_Galleries` -- `VM_Images_Classic` -- `VM_Scale_Sets` -- `Virtual_Machine` -- `Virtual_Machines_Classic` -- `Workspaces` -- `Workspaces2` - -### containers (7) - -- `App_Services` -- `Azure_Red_Hat_OpenShift` -- `Batch_Accounts` -- `Container_Instances` -- `Container_Registries` -- `Kubernetes_Services` -- `Service_Fabric_Clusters` - -### databases (27) - -- `Azure_Cosmos_DB` -- `Azure_Data_Explorer_Clusters` -- `Azure_Database_MariaDB_Server` -- `Azure_Database_Migration_Services` -- `Azure_Database_MySQL_Server` -- `Azure_Database_PostgreSQL_Server` -- `Azure_Database_PostgreSQL_Server_Group` -- `Azure_Purview_Accounts` -- `Azure_SQL` -- `Azure_SQL_Edge` -- `Azure_SQL_Server_Stretch_Databases` -- `Azure_SQL_VM` -- `Azure_Synapse_Analytics` -- `Cache_Redis` -- `Data_Factory` -- `Elastic_Job_Agents` -- `Instance_Pools` -- `Managed_Database` -- `Oracle_Database` -- `SQL_Data_Warehouses` -- `SQL_Database` -- `SQL_Elastic_Pools` -- `SQL_Managed_Instance` -- `SQL_Server` -- `SQL_Server_Registries` -- `SSIS_Lift_And_Shift_IR` -- `Virtual_Clusters` - -### identity (35) - -- `AAD_Licenses` -- `Active_Directory_Connect_Health` -- `Active_Directory_Connect_Health2` -- `Administrative_Units` -- `App_Registrations` -- `Azure_AD_B2C` -- `Azure_AD_B2C2` -- `Azure_AD_Domain_Services` -- `Azure_AD_Identity_Protection` -- `Azure_AD_Privilege_Identity_Management` -- 
`Azure_Active_Directory` -- `Azure_Information_Protection` -- `Custom_Azure_AD_Roles` -- `Enterprise_Applications` -- `Entra_Connect` -- `Entra_Domain_Services` -- `Entra_Global_Secure_Access` -- `Entra_ID_Protection` -- `Entra_Internet_Access` -- `Entra_Managed_Identities` -- `Entra_Private_Access` -- `Entra_Privileged_Identity_Management` -- `Entra_Verified_ID` -- `External_Identities` -- `Groups` -- `Identity_Governance` -- `Managed_Identities` -- `Multi_Factor_Authentication` -- `PIM` -- `Security` -- `Tenant_Properties` -- `User_Settings` -- `Users` -- `Verifiable_Credentials` -- `Verification_As_A_Service` - -### networking (51) - -- `ATM_Multistack` -- `Application_Gateway_Containers` -- `Application_Gateways` -- `Azure_Communications_Gateway` -- `Azure_Firewall_Manager` -- `Azure_Firewall_Policy` -- `Bastions` -- `CDN_Profiles` -- `Connections` -- `DDoS_Protection_Plans` -- `DNS_Multistack` -- `DNS_Private_Resolver` -- `DNS_Security_Policy` -- `DNS_Zones` -- `ExpressRoute_Circuits` -- `Firewalls` -- `Front_Doors` -- `IP_Address_manager` -- `IP_Groups` -- `Load_Balancer_Hub` -- `Load_Balancers` -- `Local_Network_Gateways` -- `NAT` -- `Network_Interfaces` -- `Network_Security_Groups` -- `Network_Watcher` -- `On_Premises_Data_Gateways` -- `Private_Endpoint` -- `Private_Link` -- `Private_Link_Hub` -- `Private_Link_Service` -- `Proximity_Placement_Groups` -- `Public_IP_Addresses` -- `Public_IP_Addresses_Classic` -- `Public_IP_Prefixes` -- `Reserved_IP_Addresses_Classic` -- `Resource_Management_Private_Link` -- `Route_Filters` -- `Route_Tables` -- `Service_Endpoint_Policies` -- `Spot_VM` -- `Spot_VMSS` -- `Subnet` -- `Traffic_Manager_Profiles` -- `Virtual_Network_Gateways` -- `Virtual_Networks` -- `Virtual_Networks_Classic` -- `Virtual_Router` -- `Virtual_WAN_Hub` -- `Virtual_WANs` -- `Web_Application_Firewall_Policies_WAF` - -### security (14) - -- `Application_Security_Groups` -- `Azure_AD_Risky_Signins` -- `Azure_AD_Risky_Users` -- `Azure_Defender` -- 
`Azure_Sentinel` -- `Conditional_Access` -- `Detonation` -- `ExtendedSecurityUpdates` -- `Identity_Secure_Score` -- `Key_Vaults` -- `Keys` -- `MS_Defender_EASM` -- `Multifactor_Authentication` -- `Security_Center` - -### storage (17) - -- `Azure_Fileshare` -- `Azure_HCP_Cache` -- `Azure_NetApp_Files` -- `Azure_Stack_Edge` -- `Data_Box` -- `Data_Box_Edge` -- `Data_Lake_Storage_Gen1` -- `Data_Share_Invitations` -- `Data_Shares` -- `Import_Export_Jobs` -- `Recovery_Services_Vaults` -- `StorSimple_Data_Managers` -- `StorSimple_Device_Managers` -- `Storage_Accounts` -- `Storage_Accounts_Classic` -- `Storage_Explorer` -- `Storage_Sync_Services` - -### general (98) - -- `All_Resources` -- `Backlog` -- `Biz_Talk` -- `Blob_Block` -- `Blob_Page` -- `Branch` -- `Browser` -- `Bug` -- `Builds` -- `Cache` -- `Code` -- `Commit` -- `Controls` -- `Controls_Horizontal` -- `Cost_Alerts` -- `Cost_Analysis` -- `Cost_Budgets` -- `Cost_Management` -- `Cost_Management_and_Billing` -- `Counter` -- `Cubes` -- `Dashboard` -- `Dashboard2` -- `Dev_Console` -- `Download` -- `Error` -- `Extensions` -- `FTP` -- `File` -- `Files` -- `Folder_Blank` -- `Folder_Website` -- `Free_Services` -- `Gear` -- `Globe` -- `Globe_Error` -- `Globe_Success` -- `Globe_Warning` -- `Guide` -- `Heart` -- `Help_and_Support` -- `Image` -- `Information` -- `Input_Output` -- `Journey_Hub` -- `Launch_Portal` -- `Learn` -- `Load_Test` -- `Location` -- `Log_Streaming` -- `Management_Groups` -- `Management_Portal` -- `Marketplace` -- `Media` -- `Media_File` -- `Mobile` -- `Mobile_Engagement` -- `Module` -- `Power` -- `Power_Up` -- `Powershell` -- `Preview` -- `Preview_Features` -- `Process_Explorer` -- `Production_Ready_Database` -- `Quickstart_Center` -- `Recent` -- `Reservations` -- `Resource_Explorer` -- `Resource_Group_List` -- `Resource_Groups` -- `Resource_Linked` -- `SSD` -- `Scale` -- `Scheduler` -- `Search` -- `Search_Grid` -- `Server_Farm` -- `Service_Bus` -- `Service_Health` -- `Storage_Azure_Files` -- 
`Storage_Container` -- `Storage_Queue` -- `Subscriptions` -- `TFS_VC_Repository` -- `Table` -- `Tag` -- `Tags` -- `Templates` -- `Toolbox` -- `Troubleshoot` -- `Versions` -- `Web_Slots` -- `Web_Test` -- `Website_Power` -- `Website_Staging` -- `Workbooks` -- `Workflow` - -### other (149) - -(See draw.io for complete list of 149 shapes in the "other" category) - -Selected shapes: -- `Azure_Backup_Center` -- `Azure_Chaos_Studio` -- `Azure_Cloud_Shell` -- `Azure_Communication_Services` -- `Azure_Deployment_Environments` -- `Azure_Load_Testing` -- `Azure_Monitor_Dashboard` -- `Azure_Network_Manager` -- `Azure_Orbital` -- `Azure_Sphere` -- `Azure_Storage_Mover` -- `Grafana` -- `Kubernetes_Fleet_Manager` -- `SSH_Keys` - -### Additional Categories - -- **azure_ecosystem** (3): Applens, Azure_Hybrid_Center, Collaborative_Service -- **azure_stack** (8): Azure_Stack, Capacity, Infrastructure_Backup, Multi_Tenancy, Offers, Plans, Updates, User_Subscriptions -- **azure_vmware_solution** (1): AVS -- **blockchain** (6): ABS_Member, Azure_Blockchain_Service, Azure_Token_Service, Blockchain_Applications, Consortium, Outbound_Connection -- **cxp** (2): Elixir, Elixir_Purple -- **devops** (10): API_Connections, Application_Insights, Azure_DevOps, Change_Analysis, CloudTest, Code_Optimization, DevOps_Starter, DevTest_Labs, Lab_Accounts, Lab_Services -- **hybrid_multicloud** (5): Azure_Operator_5G_Core, Azure_Operator_Insights, Azure_Operator_Nexus, Azure_Operator_Service_Manager, Azure_Programmable_Connectivity -- **integration** (21): API_Management_Services, App_Configuration, Azure_API_for_FHIR, Azure_Data_Catalog, Event_Grid_Domains, Event_Grid_Subscriptions, Event_Grid_Topics, Integration_Accounts, Integration_Environments, Integration_Service_Environments, Logic_Apps, Logic_Apps_Custom_Connector, Partner_Namespace, Partner_Registration, Partner_Topic, Relays, SQL_Data_Warehouses, SendGrid_Accounts, Service_Bus, Software_as_a_Service, System_Topic -- **internet_of_things** (3): 
Digital_Twins, Logic_Apps, Time_Series_Insights_Access_Policies -- **intune** (17): Azure_AD_Roles_and_Administrators, Client_Apps, Device_Compliance, Device_Configuration, Device_Enrollment, Device_Security_Apple, Device_Security_Google, Device_Security_Windows, Devices, Exchange_Access, Intune, Intune_For_Education, Mindaro, Security_Baselines, Software_Updates, Tenant_Status, eBooks -- **iot** (19): Azure_IoT_Operations, Azure_Maps_Accounts, Azure_Stack_HCI_Sizer, Device_Provisioning_Services, Digital_Twins, Event_Hubs, Function_Apps, Industrial_IoT, IoT_Central_Applications, IoT_Edge, IoT_Hub, Logic_Apps, Notification_Hubs, Stack_HCI_Premium, Stream_Analytics_Jobs, Time_Series_Data_Sets, Time_Series_Insights_Environments, Time_Series_Insights_Event_Sources, Windows10_Core_Services -- **management_governance** (32): Activity_Log, Advisor, Alerts, Application_Insights, Arc_Machines, Automation_Accounts, Azure_Arc, Azure_Lighthouse, Blueprints, Compliance, Cost_Management_and_Billing, Customer_Lockbox_for_MS_Azure, Diagnostics_Settings, Education, Log_Analytics_Workspaces, MachinesAzureArc, Managed_Applications_Center, Managed_Desktop, Metrics, Monitor, My_Customers, Operation_Log_Classic, Policy, Recovery_Services_Vaults, Resource_Graph_Explorer, Resources_Provider, Scheduler_Job_Collections, Service_Catalog_MAD, Service_Providers, Solutions, Universal_Print, User_Privacy -- **menu** (1): Keys -- **migrate** (5): Azure_Migrate, Cost_Management_and_Billing, Data_Box, Data_Box_Edge, Recovery_Services_Vaults -- **mixed_reality** (2): Remote_Rendering, Spatial_Anchor_Accounts -- **monitor** (1): SAP_Azure_Monitor -- **power_platform** (9): AIBuilder, CopilotStudio, Dataverse, PowerApps, PowerAutomate, PowerBI, PowerFx, PowerPages, PowerPlatform -- **preview** (9): Azure_Cloud_Shell, Azure_Sphere, Azure_Workbooks, IoT_Edge, Private_Link_Hub, RTOS, Static_Apps, Time_Series_Data_Sets, Web_Environment -- **web** (5): API_Center, App_Space, Azure_Media_Service, 
Notification_Hub_Namespaces, SignalR diff --git a/dataflow_agent/toolkits/shape_libraries/gcp2.md b/dataflow_agent/toolkits/shape_libraries/gcp2.md deleted file mode 100644 index ceb032b..0000000 --- a/dataflow_agent/toolkits/shape_libraries/gcp2.md +++ /dev/null @@ -1,315 +0,0 @@ -# gcp2 - -**Type:** mxgraph shapes -**Prefix:** `mxgraph.gcp2` - -## Usage - -```xml - - - -``` - - - -## Shapes (298) - -- `a7_power` -- `admin_connected` -- `admob` -- `advanced_solutions_lab` -- `ai_hub` -- `anomaly_detection` -- `api_analytics` -- `api_monetization` -- `apigee_api_platform` -- `apigee_sense` -- `app_engine` -- `app_engine_icon` -- `application` -- `application_system` -- `arrow_cycle` -- `arrows_system` -- `aspect_ratio` -- `automl_natural_language` -- `automl_tables` -- `automl_translation` -- `automl_video_intelligence` -- `automl_vision` -- `avere` -- `beacon` -- `beyondcorp` -- `big_query` -- `bigquery` -- `biomedical_beaker` -- `biomedical_test_tube` -- `biomedical_trio` -- `blank` -- `blue_hexagon` -- `bucket` -- `bucket_scale` -- `calculator` -- `campaign_manager` -- `capabilities` -- `certified_industry_standard` -- `check` -- `check_2` -- `check_available` -- `check_scale` -- `circuit_board` -- `clock` -- `cloud` -- `cloud_apis` -- `cloud_armor` -- `cloud_automl` -- `cloud_bigtable` -- `cloud_cdn` -- `cloud_checkmark` -- `cloud_code` -- `cloud_composer` -- `cloud_computer` -- `cloud_connected_insight` -- `cloud_data_catalog` -- `cloud_data_fusion` -- `cloud_dataflow` -- `cloud_dataflow_icon` -- `cloud_datalab` -- `cloud_dataprep` -- `cloud_dataproc` -- `cloud_dataproc_icon` -- `cloud_datastore` -- `cloud_deployment_manager` -- `cloud_dns` -- `cloud_endpoints` -- `cloud_external_ip_addresses` -- `cloud_filestore` -- `cloud_firestore` -- `cloud_firewall_rules` -- `cloud_functions` -- `cloud_iam` -- `cloud_inference_api` -- `cloud_information` -- `cloud_iot_core` -- `cloud_iot_edge` -- `cloud_jobs_api` -- `cloud_load_balancing` -- `cloud_machine_learning` -- 
`cloud_memorystore` -- `cloud_messaging` -- `cloud_monitoring` -- `cloud_nat` -- `cloud_natural_language_api` -- `cloud_network` -- `cloud_pubsub` -- `cloud_router` -- `cloud_routes` -- `cloud_run` -- `cloud_scheduler` -- `cloud_security` -- `cloud_security_command_center` -- `cloud_security_scanner` -- `cloud_server` -- `cloud_service_mesh` -- `cloud_spanner` -- `cloud_speech_api` -- `cloud_sql` -- `cloud_storage` -- `cloud_sub_pub` -- `cloud_tasks` -- `cloud_test_lab` -- `cloud_text_to_speech` -- `cloud_tools_for_powershell` -- `cloud_tpu` -- `cloud_translation_api` -- `cloud_video_intelligence_api` -- `cloud_vision_api` -- `cloud_vpn` -- `cluster` -- `compute_engine` -- `compute_engine_2` -- `compute_engine_icon` -- `connected` -- `container_builder` -- `container_engine` -- `container_engine_icon` -- `container_optimized_os` -- `container_registry` -- `cost` -- `cost_arrows` -- `cost_savings` -- `data_access` -- `data_increase` -- `data_loss_prevention_api` -- `data_storage_cost` -- `data_studio` -- `database` -- `database_2` -- `database_3` -- `database_cycle` -- `database_speed` -- `database_uploading` -- `debugger` -- `dedicated_game_server` -- `dedicated_interconnect` -- `desktop` -- `desktop_and_mobile` -- `developer_portal` -- `dialogflow_enterprise_edition` -- `enhance_ui` -- `enhance_ui_2` -- `error_reporting` -- `external_data_center` -- `external_data_resource` -- `external_payment_form` -- `fastly` -- `files` -- `firebase` -- `folders` -- `forseti_lockup` -- `forseti_logo` -- `frontend_platform_services` -- `game` -- `gateway` -- `gateway_icon` -- `gear` -- `gear_arrow` -- `gear_chain` -- `gear_load` -- `genomics` -- `gke_on_prem` -- `globe_world` -- `google_ad_manager` -- `google_ads` -- `google_analytics` -- `google_analytics_360` -- `google_cloud_platform` -- `google_cloud_platform_lockup` -- `google_network` -- `google_network_edge_cache` -- `google_play_game_service` -- `gpu` -- `half_cloud` -- `https_load_balancer` -- `identity_aware_proxy` -- 
`image_services` -- `increase_cost_arrows` -- `internal_payment_authorization` -- `internet_connection` -- `istio_logo` -- `key` -- `key_management_service` -- `kubernetes_logo` -- `kubernetes_name` -- `laptop` -- `legacy_cloud` -- `legacy_cloud_2` -- `lifecycle` -- `lightbulb` -- `list` -- `live` -- `load_balancing` -- `loading` -- `loading_2` -- `loading_3` -- `lock` -- `logging` -- `logs_api` -- `management_security` -- `maps_api` -- `mem_instances` -- `memcache` -- `memory_card` -- `mobile_devices` -- `modifiers_autoscaling` -- `modifiers_custom_virtual_machine` -- `modifiers_high_cpu_machine` -- `modifiers_high_memory_machine` -- `modifiers_preemptable_vm` -- `modifiers_shared_core_machine_f1` -- `modifiers_shared_core_machine_g1` -- `modifiers_standard_machine` -- `modifiers_storage` -- `monitor` -- `monitor_2` -- `mxgraph.gcp2` -- `nat` -- `network` -- `network_load_balancer` -- `node` -- `outline_blank_1` -- `outline_blank_2` -- `outline_blank_3` -- `outline_highcomp` -- `outline_highmem` -- `partner_interconnect` -- `payment` -- `people_security_management` -- `persistent_disk` -- `persistent_disk_snapshot` -- `phone` -- `phone_android` -- `placeholder` -- `play_gear` -- `play_start` -- `prediction_api` -- `premium_network_tier` -- `primary` -- `process` -- `profiler` -- `push_notification_service` -- `recommendations_ai` -- `record` -- `replication_controller` -- `replication_controller_2` -- `replication_controller_3` -- `report` -- `repository` -- `repository_2` -- `repository_3` -- `repository_primary` -- `retail` -- `safety` -- `save` -- `scale` -- `scheduled_tasks` -- `search` -- `search_api` -- `security_key_enforcement` -- `segments` -- `segments_2` -- `segments_overlap` -- `servers_stacked` -- `service` -- `service_discovery` -- `social_media_time` -- `solution` -- `speaker` -- `speed` -- `squid_proxy` -- `stackdriver` -- `stacked_ownership` -- `standard_network_tier` -- `storage` -- `stream` -- `swap` -- `systems_check` -- `tape_record` -- 
`task_queues` -- `task_queues_2` -- `tensorflow_lockup` -- `tensorflow_logo` -- `thumbs_up` -- `time_clock` -- `trace` -- `traffic_director` -- `transfer_appliance` -- `users` -- `view_list` -- `virtual_file_system` -- `virtual_private_cloud` -- `visibility` -- `vpn` -- `vpn_gateway` -- `webcam` -- `website` diff --git a/dataflow_agent/toolkits/shape_libraries/kubernetes.md b/dataflow_agent/toolkits/shape_libraries/kubernetes.md deleted file mode 100644 index 61e7a40..0000000 --- a/dataflow_agent/toolkits/shape_libraries/kubernetes.md +++ /dev/null @@ -1,58 +0,0 @@ -# kubernetes - -**Type:** mxgraph shapes -**Prefix:** `mxgraph.kubernetes` - -## Usage - -```xml - - - -``` - - - -## Shapes (41) - -- `api` -- `c_c_m` -- `c_m` -- `c_role` -- `cm` -- `crb` -- `crd` -- `cronjob` -- `deploy` -- `ds` -- `ep` -- `etcd` -- `frame` -- `group` -- `hpa` -- `ing` -- `job` -- `k_proxy` -- `kubelet` -- `limits` -- `master` -- `mxgraph.kubernetes` -- `netpol` -- `node` -- `ns` -- `pod` -- `psp` -- `pv` -- `pvc` -- `quota` -- `rb` -- `role` -- `rs` -- `sa` -- `sc` -- `sched` -- `secret` -- `sts` -- `svc` -- `user` -- `vol` diff --git a/dataflow_agent/trajectory/__init__.py b/dataflow_agent/trajectory/__init__.py deleted file mode 100644 index 242478a..0000000 --- a/dataflow_agent/trajectory/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -DFA Trajectory (TRJ) - Workflow 执行轨迹导出模块 - -用于捕获、构建和导出 Workflow 执行过程数据,支持: -- ReAct 模式轨迹 -- Workflow 模式轨迹 -- 多模态数据 -""" - -from dataflow_agent.trajectory.models import ( - TrajectoryStep, - Trajectory, - TrajectoryMode, - StepRole, - ActionType, -) -from dataflow_agent.trajectory.collector import TrajectoryCollector -from dataflow_agent.trajectory.builder import TrajectoryBuilder -from dataflow_agent.trajectory.exporter import TrajectoryExporter -from dataflow_agent.trajectory.manager import TrajectoryManager - -__all__ = [ - # 数据模型 - "TrajectoryStep", - "Trajectory", - "TrajectoryMode", - "StepRole", - "ActionType", - # 核心组件 - 
"TrajectoryCollector", - "TrajectoryBuilder", - "TrajectoryExporter", - "TrajectoryManager", -] diff --git a/dataflow_agent/trajectory/builder.py b/dataflow_agent/trajectory/builder.py deleted file mode 100644 index 2ee39d2..0000000 --- a/dataflow_agent/trajectory/builder.py +++ /dev/null @@ -1,314 +0,0 @@ -""" -轨迹构建器 - 将收集的原始数据转换为标准 TRJ 格式 - -支持: -1. 从 State 和 Collector 构建完整轨迹 -2. 自动检测 ReAct/Workflow 模式 -3. 提取关键信息和统计数据 -""" - -from datetime import datetime -from typing import Any, Dict, List, Optional - -from dataflow_agent.trajectory.models import ( - Trajectory, - TrajectoryStep, - TrajectoryMode, - StepRole, -) -from dataflow_agent.trajectory.collector import TrajectoryCollector -from dataflow_agent.state import MainState -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - - -class TrajectoryBuilder: - """ - 轨迹构建器 - - 将 TrajectoryCollector 收集的原始步骤数据和 State 对象 - 转换为标准的 Trajectory 对象 - """ - - def __init__(self): - pass - - def build_from_state(self, - state: MainState, - collector: TrajectoryCollector, - workflow_name: str, - user_id: str = None, - session_id: str = None) -> Trajectory: - """ - 从 State 和 Collector 构建完整轨迹 - - Args: - state: Workflow 执行后的最终状态 - collector: 轨迹收集器 - workflow_name: Workflow 名称 - user_id: 用户 ID - session_id: 会话 ID - - Returns: - 完整的 Trajectory 对象 - """ - log.info(f"[TrajectoryBuilder] 开始构建轨迹: {workflow_name}") - - # 1. 生成 trace_id - trace_id = Trajectory.generate_trace_id() - - # 2. 获取步骤 - steps = collector.finish() - - # 3. 检测模式 - mode = self._detect_mode(state, steps) - - # 4. 提取输入 - inputs = self._extract_inputs(state, collector) - - # 5. 提取输出 - final_output = self._extract_final_output(state) - - # 6. 判断状态 - status = self._determine_status(state, steps) - - # 7. 计算统计信息 - total_duration_ms = self._calculate_total_duration(steps) - total_tokens = self._calculate_total_tokens(steps) - - # 8. 
构建 Trajectory - trajectory = Trajectory( - trace_id=trace_id, - workflow_name=workflow_name, - timestamp=datetime.now().isoformat(), - status=status, - mode=mode, - user_id=user_id, - session_id=session_id or getattr(state.request, 'session_id', None), - inputs=inputs, - steps=steps, - final_output=final_output, - total_duration_ms=total_duration_ms, - total_tokens=total_tokens, - metadata=collector.get_metadata() - ) - - # 更新统计 - trajectory.total_llm_calls = sum(len(step.llm_calls) for step in steps) - trajectory.total_tool_calls = sum(len(step.tool_calls) for step in steps) - - log.info(f"[TrajectoryBuilder] 轨迹构建完成: {trace_id}, " - f"模式={mode}, 步骤数={len(steps)}, 状态={status}") - - return trajectory - - def build_from_steps(self, - steps: List[TrajectoryStep], - workflow_name: str, - inputs: Dict[str, Any] = None, - final_output: Any = None, - **kwargs) -> Trajectory: - """ - 直接从步骤列表构建轨迹(不依赖 State) - - Args: - steps: 步骤列表 - workflow_name: Workflow 名称 - inputs: 输入数据 - final_output: 最终输出 - **kwargs: 其他参数 - - Returns: - Trajectory 对象 - """ - trace_id = Trajectory.generate_trace_id() - mode = self._detect_mode_from_steps(steps) - status = "success" if not any(step.error for step in steps) else "failed" - - trajectory = Trajectory( - trace_id=trace_id, - workflow_name=workflow_name, - timestamp=datetime.now().isoformat(), - status=status, - mode=mode, - inputs=inputs or {}, - steps=steps, - final_output=final_output, - **kwargs - ) - - # 更新统计 - trajectory.total_llm_calls = sum(len(step.llm_calls) for step in steps) - trajectory.total_tool_calls = sum(len(step.tool_calls) for step in steps) - trajectory.total_duration_ms = self._calculate_total_duration(steps) - trajectory.total_tokens = self._calculate_total_tokens(steps) - - return trajectory - - def _detect_mode(self, state: MainState, steps: List[TrajectoryStep]) -> str: - """ - 检测轨迹模式 - - 通过分析 State 和 Steps 判断是 ReAct 还是 Workflow 模式 - """ - # 检查是否有 thought 字段(ReAct 特征) - has_thoughts = any(step.thought for step in 
steps) - - # 检查是否有 observation 字段(ReAct 特征) - has_observations = any(step.observation for step in steps) - - # 检查消息历史(ReAct 通常有更多的对话轮次) - messages = getattr(state, 'messages', []) - has_many_messages = len(messages) > 5 - - # 检查是否有明确的 agent 角色步骤 - has_agent_steps = any(step.role == StepRole.AGENT.value for step in steps) - - if has_thoughts or has_observations: - return TrajectoryMode.REACT.value - elif has_agent_steps and has_many_messages: - return TrajectoryMode.HYBRID.value - else: - return TrajectoryMode.WORKFLOW.value - - def _detect_mode_from_steps(self, steps: List[TrajectoryStep]) -> str: - """仅从步骤检测模式""" - has_thoughts = any(step.thought for step in steps) - has_observations = any(step.observation for step in steps) - - if has_thoughts or has_observations: - return TrajectoryMode.REACT.value - else: - return TrajectoryMode.WORKFLOW.value - - def _extract_inputs(self, state: MainState, collector: TrajectoryCollector) -> Dict[str, Any]: - """ - 提取输入数据 - - 优先级: - 1. Collector 记录的初始输入 - 2. State.request 中的字段 - 3. 
State 的其他相关字段 - """ - inputs = {} - - # 从 collector 获取 - collector_inputs = collector.get_initial_inputs() - if collector_inputs: - inputs.update(collector_inputs) - - # 从 state.request 提取 - if hasattr(state, 'request'): - request = state.request - - # 提取常见字段 - if hasattr(request, 'target') and request.target: - inputs['query'] = request.target - - if hasattr(request, 'model'): - inputs['model'] = request.model - - if hasattr(request, 'language'): - inputs['language'] = request.language - - # 提取其他可能的输入字段 - for field in ['json_file', 'python_file_path', 'keywords', 'style']: - if hasattr(request, field): - value = getattr(request, field) - if value: - inputs[field] = value - - return inputs - - def _extract_final_output(self, state: MainState) -> Any: - """ - 提取最终输出 - - 从 State 的不同字段中提取最终结果 - """ - # 尝试从 agent_results 获取 - if hasattr(state, 'agent_results') and state.agent_results: - # 获取最后一个 agent 的结果 - last_agent_result = None - for agent_name, result in state.agent_results.items(): - if isinstance(result, dict) and 'results' in result: - last_agent_result = result['results'] - - if last_agent_result: - return last_agent_result - - # 尝试从特定字段获取 - for field in ['final_output', 'execution_result', 'pipeline_structure_code', - 'recommendation', 'icon_prompt', 'research_summary']: - if hasattr(state, field): - value = getattr(state, field) - if value: - return value - - # 如果都没有,返回整个 state 的字典表示(简化版) - return {"status": "completed"} - - def _determine_status(self, state: MainState, steps: List[TrajectoryStep]) -> str: - """ - 判断执行状态 - - Returns: - "success" | "failed" | "partial" - """ - # 检查是否有错误步骤 - has_errors = any(step.error for step in steps) - - # 检查 execution_result - if hasattr(state, 'execution_result'): - exec_result = state.execution_result - if isinstance(exec_result, dict): - if exec_result.get('success') is False: - return "failed" - elif exec_result.get('success') is True: - return "success" - - # 根据错误情况判断 - if has_errors: - # 如果所有步骤都有错误,则失败 - if 
all(step.error for step in steps): - return "failed" - else: - return "partial" - - return "success" - - def _calculate_total_duration(self, steps: List[TrajectoryStep]) -> Optional[float]: - """计算总耗时""" - if not steps: - return None - - total = sum(step.duration_ms for step in steps if step.duration_ms) - return total if total > 0 else None - - def _calculate_total_tokens(self, steps: List[TrajectoryStep]) -> Optional[Dict[str, int]]: - """计算总 token 使用量""" - total_prompt = 0 - total_completion = 0 - - for step in steps: - for llm_call in step.llm_calls: - if llm_call.token_usage: - total_prompt += llm_call.token_usage.get('prompt', 0) - total_completion += llm_call.token_usage.get('completion', 0) - - if total_prompt > 0 or total_completion > 0: - return { - 'prompt': total_prompt, - 'completion': total_completion, - 'total': total_prompt + total_completion - } - - return None - - -# ==================== 便捷函数 ==================== - -def create_builder() -> TrajectoryBuilder: - """创建轨迹构建器""" - return TrajectoryBuilder() diff --git a/dataflow_agent/trajectory/collector.py b/dataflow_agent/trajectory/collector.py deleted file mode 100644 index 647f84d..0000000 --- a/dataflow_agent/trajectory/collector.py +++ /dev/null @@ -1,387 +0,0 @@ -""" -轨迹收集器 - 实时捕获 Workflow 执行过程数据 - -通过 Hook 机制在执行过程中收集: -- Agent 执行信息 -- LLM 调用记录 -- 工具调用记录 -- 状态变化 -""" - -import time -from datetime import datetime -from typing import Any, Dict, List, Optional -from dataclasses import dataclass, field - -from dataflow_agent.trajectory.models import ( - TrajectoryStep, - LLMCallRecord, - ToolCallRecord, - MultimodalData, - StepRole, - ActionType, -) -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - - -@dataclass -class StepContext: - """步骤上下文 - 用于跟踪当前正在执行的步骤""" - step_index: int - node_name: str - role: str - start_time: float - input_context: Dict[str, Any] = field(default_factory=dict) - llm_calls: List[LLMCallRecord] = field(default_factory=list) - tool_calls: 
List[ToolCallRecord] = field(default_factory=list) - thought: Optional[str] = None - action_type: Optional[str] = None - action_payload: Optional[Dict[str, Any]] = None - observation: Optional[str] = None - node_output: Optional[Dict[str, Any]] = None - multimodal_input: Optional[MultimodalData] = None - multimodal_output: Optional[MultimodalData] = None - error: Optional[str] = None - - -class TrajectoryCollector: - """ - 轨迹收集器 - - 使用方式: - 1. 在 Workflow 开始时调用 start() - 2. 在每个节点执行前调用 on_node_start() - 3. 在 LLM/工具调用时调用相应的 on_xxx 方法 - 4. 在每个节点执行后调用 on_node_end() - 5. 在 Workflow 结束时调用 finish() 获取所有步骤 - """ - - def __init__(self): - self.steps: List[TrajectoryStep] = [] - self.current_step: Optional[StepContext] = None - self.step_counter: int = 0 - self.start_time: Optional[float] = None - self.is_recording: bool = False - - # 初始输入 - self.initial_inputs: Dict[str, Any] = {} - - # 元数据 - self.metadata: Dict[str, Any] = {} - - def start(self, inputs: Dict[str, Any] = None, metadata: Dict[str, Any] = None): - """ - 开始记录 - - Args: - inputs: 初始输入数据 - metadata: 额外元数据 - """ - self.steps = [] - self.current_step = None - self.step_counter = 0 - self.start_time = time.time() - self.is_recording = True - self.initial_inputs = inputs or {} - self.metadata = metadata or {} - - log.info(f"[TrajectoryCollector] 开始记录轨迹") - - def finish(self) -> List[TrajectoryStep]: - """ - 结束记录并返回所有步骤 - - Returns: - 收集到的所有步骤 - """ - # 如果有未完成的步骤,先完成它 - if self.current_step: - self._finalize_current_step() - - self.is_recording = False - total_duration = (time.time() - self.start_time) * 1000 if self.start_time else 0 - - log.info(f"[TrajectoryCollector] 记录完成,共 {len(self.steps)} 个步骤," - f"总耗时 {total_duration:.2f}ms") - - return self.steps - - def on_node_start(self, - node_name: str, - role: str = StepRole.SYSTEM_NODE.value, - input_context: Dict[str, Any] = None): - """ - 节点开始执行 - - Args: - node_name: 节点名称 - role: 角色类型 - input_context: 输入上下文 - """ - if not self.is_recording: - return - - # 
如果有未完成的步骤,先完成它 - if self.current_step: - self._finalize_current_step() - - # 创建新的步骤上下文 - self.current_step = StepContext( - step_index=self.step_counter, - node_name=node_name, - role=role, - start_time=time.time(), - input_context=input_context or {} - ) - - log.debug(f"[TrajectoryCollector] 节点开始: {node_name} (step {self.step_counter})") - - def on_node_end(self, - output: Dict[str, Any] = None, - error: str = None): - """ - 节点执行结束 - - Args: - output: 节点输出 - error: 错误信息 - """ - if not self.is_recording or not self.current_step: - return - - if output: - self.current_step.node_output = output - if error: - self.current_step.error = error - - self._finalize_current_step() - - log.debug(f"[TrajectoryCollector] 节点结束: {self.current_step.node_name if self.current_step else 'unknown'}") - - def on_llm_call(self, - model: str, - messages: List[Dict[str, Any]], - response: str, - duration_ms: float = None, - token_usage: Dict[str, int] = None, - temperature: float = None): - """ - 记录 LLM 调用 - - Args: - model: 模型名称 - messages: 输入消息 - response: 响应内容 - duration_ms: 耗时(毫秒) - token_usage: Token 使用量 - temperature: 温度参数 - """ - if not self.is_recording: - return - - record = LLMCallRecord( - model=model, - messages_in=messages, - response=response, - timestamp=datetime.now().isoformat(), - duration_ms=duration_ms, - token_usage=token_usage, - temperature=temperature - ) - - if self.current_step: - self.current_step.llm_calls.append(record) - else: - # 如果没有当前步骤,创建一个临时步骤 - log.warning("[TrajectoryCollector] LLM 调用发生在节点外部") - self.on_node_start("llm_call", StepRole.AGENT.value) - self.current_step.llm_calls.append(record) - - log.debug(f"[TrajectoryCollector] 记录 LLM 调用: {model}") - - def on_tool_call(self, - tool_name: str, - tool_args: Dict[str, Any], - tool_result: Any, - duration_ms: float = None, - error: str = None): - """ - 记录工具调用 - - Args: - tool_name: 工具名称 - tool_args: 工具参数 - tool_result: 工具结果 - duration_ms: 耗时(毫秒) - error: 错误信息 - """ - if not self.is_recording: - return - - 
record = ToolCallRecord( - tool_name=tool_name, - tool_args=tool_args, - tool_result=tool_result, - timestamp=datetime.now().isoformat(), - duration_ms=duration_ms, - error=error - ) - - if self.current_step: - self.current_step.tool_calls.append(record) - else: - log.warning("[TrajectoryCollector] 工具调用发生在节点外部") - self.on_node_start("tool_call", StepRole.TOOL.value) - self.current_step.tool_calls.append(record) - - log.debug(f"[TrajectoryCollector] 记录工具调用: {tool_name}") - - def on_thought(self, thought: str): - """ - 记录 Agent 的思考过程(ReAct 模式) - - Args: - thought: 思考内容 - """ - if not self.is_recording or not self.current_step: - return - - self.current_step.thought = thought - log.debug(f"[TrajectoryCollector] 记录思考: {thought[:50]}...") - - def on_action(self, - action_type: str, - action_payload: Dict[str, Any]): - """ - 记录 Agent 的动作(ReAct 模式) - - Args: - action_type: 动作类型 - action_payload: 动作内容 - """ - if not self.is_recording or not self.current_step: - return - - self.current_step.action_type = action_type - self.current_step.action_payload = action_payload - log.debug(f"[TrajectoryCollector] 记录动作: {action_type}") - - def on_observation(self, observation: str): - """ - 记录环境观察(ReAct 模式) - - Args: - observation: 观察内容 - """ - if not self.is_recording or not self.current_step: - return - - self.current_step.observation = observation - log.debug(f"[TrajectoryCollector] 记录观察: {observation[:50]}...") - - def on_multimodal_input(self, - data_type: str, - path: str = None, - url: str = None, - metadata: Dict[str, Any] = None): - """ - 记录多模态输入 - - Args: - data_type: 数据类型(image/audio/video) - path: 文件路径 - url: URL - metadata: 元数据 - """ - if not self.is_recording or not self.current_step: - return - - self.current_step.multimodal_input = MultimodalData( - type=data_type, - path=path, - url=url, - metadata=metadata or {} - ) - log.debug(f"[TrajectoryCollector] 记录多模态输入: {data_type}") - - def on_multimodal_output(self, - data_type: str, - path: str = None, - url: str = None, - 
metadata: Dict[str, Any] = None): - """ - 记录多模态输出 - - Args: - data_type: 数据类型 - path: 文件路径 - url: URL - metadata: 元数据 - """ - if not self.is_recording or not self.current_step: - return - - self.current_step.multimodal_output = MultimodalData( - type=data_type, - path=path, - url=url, - metadata=metadata or {} - ) - log.debug(f"[TrajectoryCollector] 记录多模态输出: {data_type}") - - def _finalize_current_step(self): - """完成当前步骤并添加到步骤列表""" - if not self.current_step: - return - - # 计算耗时 - duration_ms = (time.time() - self.current_step.start_time) * 1000 - - # 创建 TrajectoryStep - step = TrajectoryStep( - step_index=self.current_step.step_index, - node_name=self.current_step.node_name, - role=self.current_step.role, - timestamp=datetime.now().isoformat(), - input_context=self.current_step.input_context, - thought=self.current_step.thought, - action_type=self.current_step.action_type, - action_payload=self.current_step.action_payload, - observation=self.current_step.observation, - node_output=self.current_step.node_output, - llm_calls=self.current_step.llm_calls, - tool_calls=self.current_step.tool_calls, - multimodal_input=self.current_step.multimodal_input, - multimodal_output=self.current_step.multimodal_output, - error=self.current_step.error, - duration_ms=duration_ms - ) - - self.steps.append(step) - self.step_counter += 1 - self.current_step = None - - def get_current_step_index(self) -> int: - """获取当前步骤索引""" - return self.step_counter - - def get_steps_count(self) -> int: - """获取已完成的步骤数量""" - return len(self.steps) - - def get_initial_inputs(self) -> Dict[str, Any]: - """获取初始输入""" - return self.initial_inputs - - def get_metadata(self) -> Dict[str, Any]: - """获取元数据""" - return self.metadata - - -# ==================== 便捷函数 ==================== - -def create_collector() -> TrajectoryCollector: - """创建轨迹收集器""" - return TrajectoryCollector() diff --git a/dataflow_agent/trajectory/exporter.py b/dataflow_agent/trajectory/exporter.py deleted file mode 100644 index 
7ea158a..0000000 --- a/dataflow_agent/trajectory/exporter.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -轨迹导出器 - 将 TRJ 导出为不同格式 - -支持: -1. JSON 格式导出(单个/批量) -2. JSONL 格式导出(用于训练数据) -3. SFT/DPO 格式转换 -4. 数据库存储(可选) -""" - -import json -import os -from pathlib import Path -from typing import Any, Dict, List, Optional, Union -from datetime import datetime - -from dataflow_agent.trajectory.models import Trajectory -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root - -log = get_logger(__name__) - -PROJDIR = get_project_root() -DEFAULT_OUTPUT_DIR = PROJDIR / "outputs" / "trajectories" - - -class TrajectoryExporter: - """ - 轨迹导出器 - - 提供多种导出格式和存储方式 - """ - - def __init__(self, output_dir: Union[str, Path] = None): - """ - Args: - output_dir: 输出目录,默认为 outputs/trajectories - """ - self.output_dir = Path(output_dir) if output_dir else DEFAULT_OUTPUT_DIR - self.output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"[TrajectoryExporter] 输出目录: {self.output_dir}") - - def export_to_json(self, - trajectory: Trajectory, - filepath: str = None, - pretty: bool = True) -> str: - """ - 导出为 JSON 文件 - - Args: - trajectory: 轨迹对象 - filepath: 文件路径,如果为 None 则自动生成 - pretty: 是否格式化输出 - - Returns: - 保存的文件路径 - """ - if filepath is None: - filename = f"{trajectory.trace_id}.json" - filepath = self.output_dir / filename - else: - filepath = Path(filepath) - - # 确保目录存在 - filepath.parent.mkdir(parents=True, exist_ok=True) - - # 转换为字典 - data = trajectory.to_dict() - - # 写入文件 - with open(filepath, 'w', encoding='utf-8') as f: - if pretty: - json.dump(data, f, indent=2, ensure_ascii=False) - else: - json.dump(data, f, ensure_ascii=False) - - log.info(f"[TrajectoryExporter] 已导出 JSON: {filepath}") - return str(filepath) - - def export_to_jsonl(self, - trajectories: List[Trajectory], - filepath: str = None, - mode: str = "raw") -> str: - """ - 批量导出为 JSONL 文件(每行一个 JSON 对象) - - Args: - trajectories: 轨迹列表 - filepath: 文件路径 - mode: 导出模式 - - "raw": 完整的 TRJ 数据 - - "sft": SFT 
训练格式 - - "dpo": DPO 训练格式 - - Returns: - 保存的文件路径 - """ - if filepath is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"trajectories_{mode}_{timestamp}.jsonl" - filepath = self.output_dir / filename - else: - filepath = Path(filepath) - - filepath.parent.mkdir(parents=True, exist_ok=True) - - with open(filepath, 'w', encoding='utf-8') as f: - for trj in trajectories: - if mode == "raw": - data = trj.to_dict() - elif mode == "sft": - data = { - "trace_id": trj.trace_id, - "messages": trj.to_sft_format(), - "metadata": { - "workflow": trj.workflow_name, - "status": trj.status, - } - } - elif mode == "dpo": - data = trj.to_dpo_format() - else: - raise ValueError(f"Unknown mode: {mode}") - - f.write(json.dumps(data, ensure_ascii=False) + '\n') - - log.info(f"[TrajectoryExporter] 已导出 JSONL ({mode}): {filepath}, " - f"共 {len(trajectories)} 条") - return str(filepath) - - def export_sft_dataset(self, - trajectories: List[Trajectory], - filepath: str = None, - filter_success: bool = True) -> str: - """ - 导出 SFT 训练数据集 - - Args: - trajectories: 轨迹列表 - filepath: 文件路径 - filter_success: 是否只保留成功的轨迹 - - Returns: - 保存的文件路径 - """ - # 过滤 - if filter_success: - trajectories = [t for t in trajectories if t.status == "success"] - log.info(f"[TrajectoryExporter] 过滤后保留 {len(trajectories)} 条成功轨迹") - - return self.export_to_jsonl(trajectories, filepath, mode="sft") - - def export_dpo_dataset(self, - chosen_trajectories: List[Trajectory], - rejected_trajectories: List[Trajectory], - filepath: str = None) -> str: - """ - 导出 DPO 训练数据集(成对数据) - - Args: - chosen_trajectories: 正例轨迹(成功的) - rejected_trajectories: 负例轨迹(失败的) - filepath: 文件路径 - - Returns: - 保存的文件路径 - """ - if filepath is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"dpo_pairs_{timestamp}.jsonl" - filepath = self.output_dir / filename - else: - filepath = Path(filepath) - - filepath.parent.mkdir(parents=True, exist_ok=True) - - # 按 prompt 分组 - prompt_groups = {} - - for trj in 
chosen_trajectories: - prompt = trj.inputs.get("query", "") - if prompt not in prompt_groups: - prompt_groups[prompt] = {"chosen": [], "rejected": []} - prompt_groups[prompt]["chosen"].append(trj) - - for trj in rejected_trajectories: - prompt = trj.inputs.get("query", "") - if prompt not in prompt_groups: - prompt_groups[prompt] = {"chosen": [], "rejected": []} - prompt_groups[prompt]["rejected"].append(trj) - - # 生成成对数据 - pairs_count = 0 - with open(filepath, 'w', encoding='utf-8') as f: - for prompt, group in prompt_groups.items(): - chosen_list = group.get("chosen", []) - rejected_list = group.get("rejected", []) - - # 每个 chosen 和每个 rejected 配对 - for chosen in chosen_list: - for rejected in rejected_list: - pair = { - "prompt": prompt, - "chosen": chosen.to_sft_format(), - "rejected": rejected.to_sft_format(), - "metadata": { - "chosen_trace_id": chosen.trace_id, - "rejected_trace_id": rejected.trace_id, - "chosen_score": chosen.feedback.score if chosen.feedback else None, - "rejected_score": rejected.feedback.score if rejected.feedback else None, - } - } - f.write(json.dumps(pair, ensure_ascii=False) + '\n') - pairs_count += 1 - - log.info(f"[TrajectoryExporter] 已导出 DPO 数据集: {filepath}, " - f"共 {pairs_count} 对") - return str(filepath) - - def export_statistics(self, - trajectories: List[Trajectory], - filepath: str = None) -> str: - """ - 导出统计信息 - - Args: - trajectories: 轨迹列表 - filepath: 文件路径 - - Returns: - 保存的文件路径 - """ - if filepath is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"statistics_{timestamp}.json" - filepath = self.output_dir / filename - else: - filepath = Path(filepath) - - # 计算统计信息 - stats = { - "total_trajectories": len(trajectories), - "by_status": {}, - "by_mode": {}, - "by_workflow": {}, - "total_steps": 0, - "total_llm_calls": 0, - "total_tool_calls": 0, - "total_duration_ms": 0, - "avg_steps_per_trajectory": 0, - "success_rate": 0, - } - - for trj in trajectories: - # 按状态统计 - status = trj.status - 
stats["by_status"][status] = stats["by_status"].get(status, 0) + 1 - - # 按模式统计 - mode = trj.mode - stats["by_mode"][mode] = stats["by_mode"].get(mode, 0) + 1 - - # 按 workflow 统计 - workflow = trj.workflow_name - stats["by_workflow"][workflow] = stats["by_workflow"].get(workflow, 0) + 1 - - # 累计统计 - stats["total_steps"] += len(trj.steps) - stats["total_llm_calls"] += trj.total_llm_calls - stats["total_tool_calls"] += trj.total_tool_calls - if trj.total_duration_ms: - stats["total_duration_ms"] += trj.total_duration_ms - - # 计算平均值 - if len(trajectories) > 0: - stats["avg_steps_per_trajectory"] = stats["total_steps"] / len(trajectories) - success_count = stats["by_status"].get("success", 0) - stats["success_rate"] = success_count / len(trajectories) - - # 保存 - with open(filepath, 'w', encoding='utf-8') as f: - json.dump(stats, f, indent=2, ensure_ascii=False) - - log.info(f"[TrajectoryExporter] 已导出统计信息: {filepath}") - return str(filepath) - - def load_from_json(self, filepath: str) -> Trajectory: - """ - 从 JSON 文件加载轨迹 - - Args: - filepath: 文件路径 - - Returns: - Trajectory 对象 - """ - with open(filepath, 'r', encoding='utf-8') as f: - data = json.load(f) - - # 这里需要实现从字典重建 Trajectory 的逻辑 - # 简化版本:直接返回字典 - log.info(f"[TrajectoryExporter] 已加载轨迹: {filepath}") - return data - - def load_from_jsonl(self, filepath: str) -> List[Dict[str, Any]]: - """ - 从 JSONL 文件加载轨迹列表 - - Args: - filepath: 文件路径 - - Returns: - 轨迹列表 - """ - trajectories = [] - with open(filepath, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): - data = json.loads(line) - trajectories.append(data) - - log.info(f"[TrajectoryExporter] 已加载 {len(trajectories)} 条轨迹: {filepath}") - return trajectories - - -# ==================== 便捷函数 ==================== - -def create_exporter(output_dir: str = None) -> TrajectoryExporter: - """创建轨迹导出器""" - return TrajectoryExporter(output_dir) - - -def quick_export(trajectory: Trajectory, - format: str = "json", - output_dir: str = None) -> str: - """ - 快速导出单个轨迹 - - 
Args: - trajectory: 轨迹对象 - format: 导出格式(json/jsonl) - output_dir: 输出目录 - - Returns: - 保存的文件路径 - """ - exporter = create_exporter(output_dir) - - if format == "json": - return exporter.export_to_json(trajectory) - elif format == "jsonl": - return exporter.export_to_jsonl([trajectory]) - else: - raise ValueError(f"Unknown format: {format}") diff --git a/dataflow_agent/trajectory/manager.py b/dataflow_agent/trajectory/manager.py deleted file mode 100644 index 32de18e..0000000 --- a/dataflow_agent/trajectory/manager.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -轨迹管理器 - 统一的轨迹管理入口 - -提供简单易用的 API 来: -1. 开始/停止轨迹记录 -2. 自动构建和导出轨迹 -3. 批量管理轨迹 -""" - -from typing import Any, Dict, List, Optional, Union -from pathlib import Path - -from dataflow_agent.trajectory.models import Trajectory -from dataflow_agent.trajectory.collector import TrajectoryCollector -from dataflow_agent.trajectory.builder import TrajectoryBuilder -from dataflow_agent.trajectory.exporter import TrajectoryExporter -from dataflow_agent.state import MainState -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - - -class TrajectoryManager: - """ - 轨迹管理器 - 统一入口 - - 使用示例: - ```python - # 1. 创建管理器 - trj_manager = TrajectoryManager() - - # 2. 开始记录 - trj_manager.start_recording(inputs={"query": "..."}) - - # 3. 在 workflow 执行过程中,collector 会自动记录 - # (需要在 workflow 中集成 collector 的 hook) - - # 4. 停止记录并生成轨迹 - trajectory = trj_manager.stop_recording( - state=final_state, - workflow_name="my_workflow" - ) - - # 5. 
导出 - filepath = trj_manager.export(trajectory, format="json") - ``` - """ - - def __init__(self, output_dir: str = None): - """ - Args: - output_dir: 导出目录 - """ - self.collector = TrajectoryCollector() - self.builder = TrajectoryBuilder() - self.exporter = TrajectoryExporter(output_dir) - - self.is_recording = False - self.current_trajectory: Optional[Trajectory] = None - - log.info("[TrajectoryManager] 初始化完成") - - def start_recording(self, - inputs: Dict[str, Any] = None, - metadata: Dict[str, Any] = None): - """ - 开始记录轨迹 - - Args: - inputs: 初始输入数据 - metadata: 额外元数据 - """ - self.collector.start(inputs=inputs, metadata=metadata) - self.is_recording = True - log.info("[TrajectoryManager] 开始记录轨迹") - - def stop_recording(self, - state: MainState, - workflow_name: str, - user_id: str = None, - session_id: str = None) -> Trajectory: - """ - 停止记录并生成轨迹 - - Args: - state: Workflow 执行后的最终状态 - workflow_name: Workflow 名称 - user_id: 用户 ID - session_id: 会话 ID - - Returns: - 生成的 Trajectory 对象 - """ - if not self.is_recording: - log.warning("[TrajectoryManager] 未在记录状态,无法停止") - return None - - # 构建轨迹 - trajectory = self.builder.build_from_state( - state=state, - collector=self.collector, - workflow_name=workflow_name, - user_id=user_id, - session_id=session_id - ) - - self.is_recording = False - self.current_trajectory = trajectory - - log.info(f"[TrajectoryManager] 轨迹记录完成: {trajectory.trace_id}") - return trajectory - - def export(self, - trajectory: Trajectory = None, - format: str = "json", - filepath: str = None, - **kwargs) -> str: - """ - 导出轨迹 - - Args: - trajectory: 要导出的轨迹,如果为 None 则使用当前轨迹 - format: 导出格式(json/jsonl/sft/dpo) - filepath: 文件路径 - **kwargs: 其他参数 - - Returns: - 保存的文件路径 - """ - if trajectory is None: - trajectory = self.current_trajectory - - if trajectory is None: - log.error("[TrajectoryManager] 没有可导出的轨迹") - return None - - if format == "json": - return self.exporter.export_to_json(trajectory, filepath, **kwargs) - elif format == "jsonl": - return 
self.exporter.export_to_jsonl([trajectory], filepath, **kwargs) - elif format == "sft": - return self.exporter.export_to_jsonl([trajectory], filepath, mode="sft") - elif format == "dpo": - return self.exporter.export_to_jsonl([trajectory], filepath, mode="dpo") - else: - raise ValueError(f"Unknown format: {format}") - - def export_batch(self, - trajectories: List[Trajectory], - format: str = "jsonl", - filepath: str = None, - **kwargs) -> str: - """ - 批量导出轨迹 - - Args: - trajectories: 轨迹列表 - format: 导出格式 - filepath: 文件路径 - **kwargs: 其他参数 - - Returns: - 保存的文件路径 - """ - if format == "jsonl": - return self.exporter.export_to_jsonl(trajectories, filepath, **kwargs) - elif format == "sft": - return self.exporter.export_sft_dataset(trajectories, filepath, **kwargs) - else: - raise ValueError(f"Batch export not supported for format: {format}") - - def get_collector(self) -> TrajectoryCollector: - """获取收集器实例(用于手动集成)""" - return self.collector - - def get_current_trajectory(self) -> Optional[Trajectory]: - """获取当前轨迹""" - return self.current_trajectory - - def add_feedback(self, - trajectory: Trajectory = None, - score: int = None, - comment: str = None, - edited_response: str = None, - labels: List[str] = None): - """ - 添加用户反馈 - - Args: - trajectory: 轨迹对象,如果为 None 则使用当前轨迹 - score: 评分 1-5 - comment: 评论 - edited_response: 用户修改后的回答 - labels: 标签列表 - """ - if trajectory is None: - trajectory = self.current_trajectory - - if trajectory is None: - log.error("[TrajectoryManager] 没有可添加反馈的轨迹") - return - - trajectory.set_feedback( - score=score, - comment=comment, - edited_response=edited_response, - labels=labels - ) - - log.info(f"[TrajectoryManager] 已添加反馈到轨迹: {trajectory.trace_id}") - - -# ==================== 全局单例 ==================== - -_global_manager: Optional[TrajectoryManager] = None - - -def get_trajectory_manager(output_dir: str = None) -> TrajectoryManager: - """ - 获取全局轨迹管理器(单例模式) - - Args: - output_dir: 输出目录 - - Returns: - TrajectoryManager 实例 - """ - global 
_global_manager - - if _global_manager is None: - _global_manager = TrajectoryManager(output_dir) - - return _global_manager - - -def reset_trajectory_manager(): - """重置全局轨迹管理器""" - global _global_manager - _global_manager = None - - -# ==================== 便捷函数 ==================== - -def quick_record(workflow_func, - workflow_name: str, - inputs: Dict[str, Any] = None, - export_format: str = "json", - **workflow_kwargs): - """ - 快速记录 workflow 执行轨迹的装饰器/函数 - - 使用示例: - ```python - # 作为装饰器 - @quick_record(workflow_name="my_workflow") - async def my_workflow(state): - # workflow 逻辑 - return final_state - - # 或作为函数 - final_state = await quick_record( - my_workflow, - workflow_name="my_workflow", - inputs={"query": "..."}, - state=initial_state - ) - ``` - """ - import asyncio - from functools import wraps - - # 如果是装饰器用法 - if callable(workflow_func): - @wraps(workflow_func) - async def wrapper(*args, **kwargs): - manager = get_trajectory_manager() - - # 开始记录 - manager.start_recording(inputs=inputs) - - try: - # 执行 workflow - if asyncio.iscoroutinefunction(workflow_func): - result = await workflow_func(*args, **kwargs) - else: - result = workflow_func(*args, **kwargs) - - # 停止记录 - trajectory = manager.stop_recording( - state=result, - workflow_name=workflow_name - ) - - # 导出 - filepath = manager.export(trajectory, format=export_format) - log.info(f"[quick_record] 轨迹已导出: {filepath}") - - return result - - except Exception as e: - log.exception(f"[quick_record] Workflow 执行失败: {e}") - raise - - return wrapper - - # 如果是函数调用用法 - else: - raise ValueError("quick_record 应该作为装饰器使用") diff --git a/dataflow_agent/trajectory/models.py b/dataflow_agent/trajectory/models.py deleted file mode 100644 index b9d2617..0000000 --- a/dataflow_agent/trajectory/models.py +++ /dev/null @@ -1,381 +0,0 @@ -""" -TRJ 数据模型定义 - -定义标准的轨迹数据结构,支持: -1. ReAct 模式:Context -> Thought -> Action -> Observation -2. 
Workflow 模式:State_In -> Node Processing -> State_Update -""" - -from dataclasses import dataclass, field -from datetime import datetime -from enum import Enum -from typing import Any, Dict, List, Optional -import uuid - - -class TrajectoryMode(str, Enum): - """轨迹模式""" - REACT = "react" - WORKFLOW = "workflow" - HYBRID = "hybrid" # 混合模式 - - -class StepRole(str, Enum): - """步骤角色""" - AGENT = "agent" - ENVIRONMENT = "environment" - SYSTEM_NODE = "system_node" - TOOL = "tool" - USER = "user" - - -class ActionType(str, Enum): - """动作类型""" - TOOL_CALL = "tool_call" - RESPONSE = "response" - STATE_UPDATE = "state_update" - LLM_CALL = "llm_call" - MULTIMODAL = "multimodal" - - -@dataclass -class ToolCallRecord: - """工具调用记录""" - tool_name: str - tool_args: Dict[str, Any] - tool_result: Any - timestamp: str - duration_ms: Optional[float] = None - error: Optional[str] = None - - -@dataclass -class LLMCallRecord: - """LLM 调用记录""" - model: str - messages_in: List[Dict[str, Any]] # 输入消息 - response: str # 输出响应 - timestamp: str - duration_ms: Optional[float] = None - token_usage: Optional[Dict[str, int]] = None # {"prompt": x, "completion": y} - temperature: Optional[float] = None - - -@dataclass -class MultimodalData: - """多模态数据""" - type: str # "image" | "audio" | "video" - path: Optional[str] = None # 文件路径 - url: Optional[str] = None # URL - base64: Optional[str] = None # Base64 编码(不推荐存储大数据) - metadata: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class TrajectoryStep: - """ - 单个执行步骤 - - 对于 ReAct 模式: - - input_context: Agent 看到的上下文 - - thought: Agent 的思考过程 - - action_type: 动作类型(tool_call/response) - - action_payload: 动作内容 - - observation: 环境反馈 - - 对于 Workflow 模式: - - input_context: 节点输入状态 - - node_output: 节点输出/状态更新 - """ - step_index: int - node_name: str - role: str # StepRole 的值 - timestamp: str - - # 输入上下文 - input_context: Dict[str, Any] = field(default_factory=dict) - - # ReAct 特有字段 - thought: Optional[str] = None - action_type: Optional[str] = None # 
ActionType 的值 - action_payload: Optional[Dict[str, Any]] = None - observation: Optional[str] = None - - # 通用输出 - node_output: Optional[Dict[str, Any]] = None - - # 详细记录 - llm_calls: List[LLMCallRecord] = field(default_factory=list) - tool_calls: List[ToolCallRecord] = field(default_factory=list) - - # 多模态数据 - multimodal_input: Optional[MultimodalData] = None - multimodal_output: Optional[MultimodalData] = None - - # 错误信息 - error: Optional[str] = None - - # 执行时间 - duration_ms: Optional[float] = None - - # 额外元数据 - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - """转换为字典""" - result = { - "step_index": self.step_index, - "node_name": self.node_name, - "role": self.role, - "timestamp": self.timestamp, - } - - # 只添加非空字段 - if self.input_context: - result["input_context"] = self.input_context - if self.thought: - result["thought"] = self.thought - if self.action_type: - result["action_type"] = self.action_type - if self.action_payload: - result["action_payload"] = self.action_payload - if self.observation: - result["observation"] = self.observation - if self.node_output: - result["node_output"] = self.node_output - if self.llm_calls: - result["llm_calls"] = [self._llm_call_to_dict(c) for c in self.llm_calls] - if self.tool_calls: - result["tool_calls"] = [self._tool_call_to_dict(c) for c in self.tool_calls] - if self.multimodal_input: - result["multimodal_input"] = self._multimodal_to_dict(self.multimodal_input) - if self.multimodal_output: - result["multimodal_output"] = self._multimodal_to_dict(self.multimodal_output) - if self.error: - result["error"] = self.error - if self.duration_ms is not None: - result["duration_ms"] = self.duration_ms - if self.metadata: - result["metadata"] = self.metadata - - return result - - @staticmethod - def _llm_call_to_dict(call: LLMCallRecord) -> Dict[str, Any]: - return { - "model": call.model, - "messages_in": call.messages_in, - "response": call.response, - "timestamp": 
call.timestamp, - "duration_ms": call.duration_ms, - "token_usage": call.token_usage, - "temperature": call.temperature, - } - - @staticmethod - def _tool_call_to_dict(call: ToolCallRecord) -> Dict[str, Any]: - return { - "tool_name": call.tool_name, - "tool_args": call.tool_args, - "tool_result": call.tool_result, - "timestamp": call.timestamp, - "duration_ms": call.duration_ms, - "error": call.error, - } - - @staticmethod - def _multimodal_to_dict(data: MultimodalData) -> Dict[str, Any]: - result = {"type": data.type} - if data.path: - result["path"] = data.path - if data.url: - result["url"] = data.url - if data.metadata: - result["metadata"] = data.metadata - # 不导出 base64 以减小文件大小 - return result - - -@dataclass -class TrajectoryFeedback: - """用户反馈""" - score: Optional[int] = None # 1-5 评分 - comment: Optional[str] = None - edited_response: Optional[str] = None # 用户修改后的回答(用于 SFT) - labels: List[str] = field(default_factory=list) # 标签,如 ["good", "accurate"] - timestamp: Optional[str] = None - - -@dataclass -class Trajectory: - """ - 完整的执行轨迹 - - 包含三个核心部分: - 1. Metadata: 元数据 - 2. Steps: 执行步骤列表 - 3. 
Outcome: 最终结果和反馈 - """ - # ===== 元数据 ===== - trace_id: str - workflow_name: str - timestamp: str - status: str # "success" | "failed" | "partial" - mode: str # TrajectoryMode 的值 - - # 可选元数据 - user_id: Optional[str] = None - session_id: Optional[str] = None - version: str = "1.0" - - # ===== 输入 ===== - inputs: Dict[str, Any] = field(default_factory=dict) - - # ===== 执行步骤 ===== - steps: List[TrajectoryStep] = field(default_factory=list) - - # ===== 输出 ===== - final_output: Any = None - - # ===== 反馈 ===== - feedback: Optional[TrajectoryFeedback] = None - - # ===== 统计信息 ===== - total_duration_ms: Optional[float] = None - total_llm_calls: int = 0 - total_tool_calls: int = 0 - total_tokens: Optional[Dict[str, int]] = None - - # ===== 额外元数据 ===== - metadata: Dict[str, Any] = field(default_factory=dict) - - @staticmethod - def generate_trace_id() -> str: - """生成唯一的 trace_id""" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - unique_id = uuid.uuid4().hex[:8] - return f"trj_{timestamp}_{unique_id}" - - def add_step(self, step: TrajectoryStep): - """添加执行步骤""" - self.steps.append(step) - # 更新统计 - self.total_llm_calls += len(step.llm_calls) - self.total_tool_calls += len(step.tool_calls) - - def set_feedback(self, score: int = None, comment: str = None, - edited_response: str = None, labels: List[str] = None): - """设置用户反馈""" - self.feedback = TrajectoryFeedback( - score=score, - comment=comment, - edited_response=edited_response, - labels=labels or [], - timestamp=datetime.now().isoformat() - ) - - def to_dict(self) -> Dict[str, Any]: - """转换为字典(用于 JSON 导出)""" - result = { - # 元数据 - "trace_id": self.trace_id, - "workflow_name": self.workflow_name, - "timestamp": self.timestamp, - "status": self.status, - "mode": self.mode, - "version": self.version, - - # 输入 - "inputs": self.inputs, - - # 步骤 - "steps": [step.to_dict() for step in self.steps], - - # 输出 - "final_output": self.final_output, - - # 统计 - "statistics": { - "total_steps": len(self.steps), - "total_llm_calls": 
self.total_llm_calls, - "total_tool_calls": self.total_tool_calls, - "total_duration_ms": self.total_duration_ms, - "total_tokens": self.total_tokens, - } - } - - # 可选字段 - if self.user_id: - result["user_id"] = self.user_id - if self.session_id: - result["session_id"] = self.session_id - if self.feedback: - result["feedback"] = { - "score": self.feedback.score, - "comment": self.feedback.comment, - "edited_response": self.feedback.edited_response, - "labels": self.feedback.labels, - "timestamp": self.feedback.timestamp, - } - if self.metadata: - result["metadata"] = self.metadata - - return result - - def to_sft_format(self) -> List[Dict[str, str]]: - """ - 转换为 SFT 训练格式(OpenAI messages 格式) - - Returns: - [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - """ - messages = [] - - for step in self.steps: - if step.role == StepRole.AGENT.value: - # Agent 的输出 - content = "" - if step.thought: - content += f"{step.thought}\n" - if step.action_type == ActionType.TOOL_CALL.value and step.action_payload: - tool_name = step.action_payload.get("tool_name", "") - tool_args = step.action_payload.get("tool_args", "") - content += f"{tool_name}({tool_args})" - elif step.node_output: - content += str(step.node_output) - - if content: - messages.append({"role": "assistant", "content": content}) - - elif step.role in [StepRole.ENVIRONMENT.value, StepRole.TOOL.value]: - # 工具/环境的输出 - if step.observation: - messages.append({"role": "tool", "content": step.observation}) - - elif step.role == StepRole.USER.value: - # 用户输入 - if step.input_context: - content = step.input_context.get("query", str(step.input_context)) - messages.append({"role": "user", "content": content}) - - return messages - - def to_dpo_format(self) -> Dict[str, Any]: - """ - 转换为 DPO 训练格式 - - Returns: - {"prompt": "...", "chosen": [...], "rejected": [...]} - """ - # 提取 prompt - prompt = self.inputs.get("query", self.inputs.get("target", "")) - - # 当前轨迹作为 chosen 或 rejected - trajectory_steps = 
self.to_sft_format() - - return { - "prompt": prompt, - "trajectory": trajectory_steps, - "score": self.feedback.score if self.feedback else None, - "status": self.status, - } diff --git a/dataflow_agent/workflow/wf_image2drawio.py b/dataflow_agent/workflow/wf_image2drawio.py deleted file mode 100644 index 4832512..0000000 --- a/dataflow_agent/workflow/wf_image2drawio.py +++ /dev/null @@ -1,606 +0,0 @@ -""" -image2drawio workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Convert a single diagram image into editable DrawIO XML. - -Pipeline: -1) OCR (VLM Qwen-VL-OCR preferred, fallback to PaddleOCR) -2) Generate no-text mask + inpainting (optional) -3) SAM segmentation on clean background -4) Shape classification + color sampling -5) Text assignment + image/icon extraction -6) DrawIO XML generation -""" - -from __future__ import annotations - -import asyncio -import base64 -import os -import time -from pathlib import Path -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.agentroles import create_vlm_agent - -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool import ppt_tool -from dataflow_agent.toolkits.drawio_tools import wrap_xml -from dataflow_agent.toolkits.image2drawio import ( - classify_shape, - mask_to_bbox, - normalize_mask, - sample_fill_stroke, - save_masked_rgba, - bbox_iou_px, -) - -log = get_logger(__name__) - -TEXT_COLOR = "#111111" -TEXT_FONT_SIZE = 14 
-TEXT_FONT_STYLE = 1 # draw.io fontStyle=1 => bold - - -def _ensure_result_path(state: Paper2FigureState) -> str: - raw = getattr(state, "result_path", None) - if raw: - return raw - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "image2drawio" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _escape_xml(text: str) -> str: - if text is None: - return "" - return ( - text.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - .replace('"', """) - ) - - -def _encode_image_base64(path: str) -> str: - with open(path, "rb") as f: - data = f.read() - return base64.b64encode(data).decode("utf-8") - - -def _build_mxcell( - cell_id: str, - value: str, - style: str, - bbox_px: List[int], - parent: str = "1", - vertex: bool = True, -) -> str: - x1, y1, x2, y2 = bbox_px - w = max(1, int(x2 - x1)) - h = max(1, int(y2 - y1)) - x = int(x1) - y = int(y1) - v_attr = "1" if vertex else "0" - return ( - f"" - f"" - f"" - ) - - -def _shape_style(shape_type: str, fill_hex: str, stroke_hex: str) -> str: - if shape_type == "ellipse": - base = "shape=ellipse;" - elif shape_type == "diamond": - base = "shape=rhombus;" - else: - base = "rounded=1;" if shape_type == "rounded_rect" else "rounded=0;" - return ( - f"{base}whiteSpace=wrap;html=1;align=center;verticalAlign=middle;" - f"fillColor={fill_hex};strokeColor={stroke_hex};" - f"fontColor={TEXT_COLOR};fontStyle={TEXT_FONT_STYLE};fontSize={TEXT_FONT_SIZE};" - ) - - -def _text_style(color_hex: str) -> str: - return ( - "text;html=1;align=center;verticalAlign=middle;whiteSpace=wrap;" - f"strokeColor=none;fillColor=none;fontColor={TEXT_COLOR};" - f"fontStyle={TEXT_FONT_STYLE};fontSize={TEXT_FONT_SIZE};" - ) - - -def _image_style(data_uri: str) -> str: - safe_uri = data_uri.replace(";", "%3B") - return f"shape=image;imageAspect=0;aspect=fixed;image={safe_uri};" - - -@register("image2drawio") -def 
create_image2drawio_graph() -> GenericGraphBuilder: - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - def _init_node(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - def _input_node(state: Paper2FigureState) -> Paper2FigureState: - req = getattr(state, "request", None) - if not req: - return state - img_path = getattr(req, "input_content", None) or getattr(req, "prev_image", None) - if img_path and os.path.exists(img_path): - state.fig_draft_path = img_path - else: - log.error(f"[image2drawio] input image not found: {img_path}") - return state - - async def _ocr_node(state: Paper2FigureState) -> Paper2FigureState: - """VLM OCR preferred; fallback to PaddleOCR.""" - img_path = state.fig_draft_path - if not img_path or not os.path.exists(img_path): - state.ocr_items = [] - return state - - ocr_items: List[Dict[str, Any]] = [] - api_key = getattr(state.request, "api_key", None) or getattr(state.request, "chat_api_key", None) - use_vlm = bool(getattr(state.request, "chat_api_url", None)) and bool(api_key) - if use_vlm: - try: - agent = create_vlm_agent( - name="ImageTextBBoxAgent", - model_name=getattr(state.request, "vlm_model", "qwen-vl-ocr-2025-11-20"), - chat_api_url=getattr(state.request, "chat_api_url", None), - vlm_mode="ocr", - additional_params={"input_image": img_path}, - ) - new_state = await agent.execute(state) - bbox_res = getattr(new_state, "bbox_result", []) - except Exception as e: - log.warning(f"[image2drawio][VLM] OCR failed: {e}") - bbox_res = [] - - # Normalize to px - try: - pil_img = Image.open(img_path) - w, h = pil_img.size - VLM_SCALE = 1000.0 - for it in bbox_res or []: - if "rotate_rect" in it and "bbox" not in it: - rr = it.get("rotate_rect") - if isinstance(rr, list) and len(rr) == 5: - cx, cy, rw, rh, angle = rr - rect = ((float(cx), float(cy)), (float(rw), float(rh)), float(angle)) - box = cv2.boxPoints(rect) - x_min = np.min(box[:, 0]) - x_max = 
np.max(box[:, 0]) - y_min = np.min(box[:, 1]) - y_max = np.max(box[:, 1]) - it["bbox"] = [ - max(0.0, min(1.0, y_min / VLM_SCALE)), - max(0.0, min(1.0, x_min / VLM_SCALE)), - max(0.0, min(1.0, y_max / VLM_SCALE)), - max(0.0, min(1.0, x_max / VLM_SCALE)), - ] - if "bbox" in it: - y1_n, x1_n, y2_n, x2_n = it["bbox"] - x1 = int(x1_n * w) - y1 = int(y1_n * h) - x2 = int(x2_n * w) - y2 = int(y2_n * h) - if x2 <= x1 or y2 <= y1: - continue - ocr_items.append({ - "text": it.get("text", "").strip(), - "bbox_px": [x1, y1, x2, y2], - }) - except Exception as e: - log.warning(f"[image2drawio][VLM] normalize failed: {e}") - ocr_items = [] - - # fallback to PaddleOCR if VLM unavailable or empty - if not ocr_items: - try: - res = ppt_tool.paddle_ocr_page_with_layout(img_path) - for bbox, text, _conf in res.get("lines", []): - if not bbox or not text: - continue - x1, y1, x2, y2 = [int(round(v)) for v in bbox] - if x2 <= x1 or y2 <= y1: - continue - ocr_items.append({ - "text": text.strip(), - "bbox_px": [x1, y1, x2, y2], - }) - except Exception as e: - log.error(f"[image2drawio][PaddleOCR] failed: {e}") - - # Build no_text image - try: - pil_img = Image.open(img_path).convert("RGB") - w, h = pil_img.size - mask_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) - for it in ocr_items: - x1, y1, x2, y2 = it["bbox_px"] - pad = 2 - x1 = max(0, x1 - pad) - y1 = max(0, y1 - pad) - x2 = min(w, x2 + pad) - y2 = min(h, y2 + pad) - cv2.rectangle(mask_img, (x1, y1), (x2, y2), (255, 255, 255), -1) - - base_dir = Path(_ensure_result_path(state)) - debug_dir = base_dir / "ocr_debug" - debug_dir.mkdir(parents=True, exist_ok=True) - no_text_path = debug_dir / "no_text.png" - cv2.imwrite(str(no_text_path), mask_img) - state.no_text_path = str(no_text_path) - except Exception as e: - log.warning(f"[image2drawio] no_text mask failed: {e}") - state.no_text_path = "" - - state.ocr_items = ocr_items - return state - - async def _inpainting_node(state: Paper2FigureState) -> Paper2FigureState: - 
img_path = state.fig_draft_path - no_text_path = getattr(state, "no_text_path", "") - base_dir = Path(_ensure_result_path(state)) - clean_bg_path = base_dir / "clean_bg.png" - - api_key = getattr(state.request, "api_key", None) or getattr(state.request, "chat_api_key", None) or os.getenv("DF_API_KEY") - api_url = getattr(state.request, "chat_api_url", None) - model_name = getattr(state.request, "gen_fig_model", None) - - if api_key and api_url and model_name and no_text_path and os.path.exists(no_text_path): - prompt = "Remove all text while keeping shapes, icons, and arrows. Do not change layout or colors." - try: - await generate_or_edit_and_save_image_async( - prompt=prompt, - save_path=str(clean_bg_path), - api_url=api_url, - api_key=api_key, - model=model_name, - use_edit=True, - image_path=no_text_path, - aspect_ratio=getattr(state, "aspect_ratio", "16:9"), - resolution="2K", - ) - except Exception as e: - log.warning(f"[image2drawio] inpainting failed: {e}") - - # fallback to no_text or original - if not clean_bg_path.exists(): - if no_text_path and os.path.exists(no_text_path): - try: - import shutil - shutil.copy(no_text_path, clean_bg_path) - except Exception: - pass - elif img_path and os.path.exists(img_path): - try: - import shutil - shutil.copy(img_path, clean_bg_path) - except Exception: - pass - - state.clean_bg_path = str(clean_bg_path) if clean_bg_path.exists() else "" - # Normalize clean_bg to original image size if needed - try: - if state.clean_bg_path and img_path and os.path.exists(state.clean_bg_path) and os.path.exists(img_path): - with Image.open(img_path) as orig_img: - orig_w, orig_h = orig_img.size - with Image.open(state.clean_bg_path) as bg_img: - bg_w, bg_h = bg_img.size - if orig_w and orig_h and (orig_w != bg_w or orig_h != bg_h): - resized = bg_img.resize((orig_w, orig_h), Image.LANCZOS) - resized.save(state.clean_bg_path) - except Exception as e: - log.warning(f"[image2drawio] resize clean_bg failed: {e}") - return state - - 
async def _sam_node(state: Paper2FigureState) -> Paper2FigureState: - img_path = getattr(state, "clean_bg_path", None) or state.fig_draft_path - if not img_path or not os.path.exists(img_path): - state.layout_items = [] - return state - - base_dir = Path(_ensure_result_path(state)) - out_dir = base_dir / "sam_items" - out_dir.mkdir(parents=True, exist_ok=True) - - sam_ckpt = f"{get_project_root()}/sam_b.pt" - # optional server URLs (set SAM_SERVER_URLS env, comma-separated) - sam_server_env = os.getenv("SAM_SERVER_URLS", "").strip() - sam_server_urls = [u.strip() for u in sam_server_env.split(",") if u.strip()] - layout_items: List[Dict[str, Any]] = [] - - if sam_server_urls: - try: - layout_items = segment_layout_boxes_server( - image_path=img_path, - output_dir=str(out_dir), - server_urls=sam_server_urls, - checkpoint=sam_ckpt, - min_area=120, - min_score=0.0, - iou_threshold=0.2, - top_k=None, - nms_by="mask", - ) - except Exception as e: - log.warning(f"[image2drawio] SAM server failed: {e}, fallback to local") - layout_items = [] - - if not layout_items: - try: - layout_items = segment_layout_boxes( - image_path=img_path, - output_dir=str(out_dir), - checkpoint=sam_ckpt, - min_area=120, - min_score=0.0, - iou_threshold=0.2, - top_k=None, - nms_by="mask", - ) - except Exception as e_local: - log.error(f"[image2drawio] SAM local failed: {e_local}") - layout_items = [] - finally: - try: - free_sam_model(checkpoint=sam_ckpt) - except Exception: - pass - - # compute bbox_px - try: - with Image.open(img_path) as tmp: - w, h = tmp.size - except Exception: - w, h = 1024, 1024 - - for it in layout_items: - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * w)) - y1 = int(round(y1n * h)) - x2 = int(round(x2n * w)) - y2 = int(round(y2n * h)) - it["bbox_px"] = [x1, y1, x2, y2] - - state.layout_items = layout_items - return state - - async def _build_elements_node(state: Paper2FigureState) -> Paper2FigureState: - img_path 
= getattr(state, "clean_bg_path", None) or state.fig_draft_path - if not img_path or not os.path.exists(img_path): - state.drawio_elements = [] - return state - - image_bgr = cv2.imread(img_path) - if image_bgr is None: - state.drawio_elements = [] - return state - - base_dir = Path(_ensure_result_path(state)) - icon_dir = base_dir / "icons" - icon_dir.mkdir(parents=True, exist_ok=True) - - shapes = [] - images = [] - - # classify SAM items - for idx, it in enumerate(getattr(state, "layout_items", []) or []): - mask = it.get("mask") - bbox_px = it.get("bbox_px") - if mask is None or bbox_px is None: - if mask is None: - continue - try: - tmp_mask = normalize_mask(mask, image_bgr.shape[:2]) - bbox_px = mask_to_bbox(tmp_mask) - except Exception: - bbox_px = None - if bbox_px is None: - continue - - mask = normalize_mask(mask, image_bgr.shape[:2]) - shape_type, conf = classify_shape(mask) - - if shape_type != "unknown" and conf >= 0.8: - fill_hex, stroke_hex = sample_fill_stroke(image_bgr, mask) - shapes.append({ - "id": f"s{idx}", - "kind": "shape", - "shape_type": shape_type, - "bbox_px": bbox_px, - "fill": fill_hex, - "stroke": stroke_hex, - "text": "", - "area": it.get("area", 0), - }) - else: - out_path = icon_dir / f"icon_{idx}.png" - save_masked_rgba(image_bgr, mask, str(out_path)) - images.append({ - "id": f"i{idx}", - "kind": "image", - "bbox_px": bbox_px, - "image_path": str(out_path), - "area": it.get("area", 0), - }) - - # assign OCR text to shapes (scale OCR boxes to match clean_bg size if needed) - ocr_items = getattr(state, "ocr_items", []) or [] - try: - if state.fig_draft_path and os.path.exists(state.fig_draft_path): - with Image.open(state.fig_draft_path) as orig_img: - orig_w, orig_h = orig_img.size - else: - orig_w, orig_h = None, None - except Exception: - orig_w, orig_h = None, None - - if orig_w and orig_h: - tgt_h, tgt_w = image_bgr.shape[:2] - scale_x = tgt_w / float(orig_w) - scale_y = tgt_h / float(orig_h) - if abs(scale_x - 1.0) > 1e-3 or 
abs(scale_y - 1.0) > 1e-3: - scaled_items = [] - for it in ocr_items: - tb = it.get("bbox_px") - if not tb or len(tb) != 4: - continue - x1, y1, x2, y2 = tb - scaled_items.append({ - **it, - "bbox_px": [ - int(round(x1 * scale_x)), - int(round(y1 * scale_y)), - int(round(x2 * scale_x)), - int(round(y2 * scale_y)), - ], - }) - ocr_items = scaled_items - unassigned_text = [] - for t in ocr_items: - tb = t.get("bbox_px") - if not tb: - continue - cx = (tb[0] + tb[2]) * 0.5 - cy = (tb[1] + tb[3]) * 0.5 - best_iou = 0.0 - best_idx = -1 - for i, s in enumerate(shapes): - sb = s["bbox_px"] - if sb[0] <= cx <= sb[2] and sb[1] <= cy <= sb[3]: - iou = bbox_iou_px(tb, sb) - if iou > best_iou: - best_iou = iou - best_idx = i - if best_idx >= 0 and best_iou > 0.05: - text_val = t.get("text", "").strip() - if text_val: - if shapes[best_idx]["text"]: - shapes[best_idx]["text"] += "\n" + text_val - else: - shapes[best_idx]["text"] = text_val - else: - unassigned_text.append(t) - - texts = [] - for i, t in enumerate(unassigned_text): - tb = t.get("bbox_px") - if not tb: - continue - texts.append({ - "id": f"t{i}", - "kind": "text", - "bbox_px": tb, - "text": t.get("text", ""), - "color": TEXT_COLOR, - }) - - # sort elements by z (shapes large -> small, then images, then texts) - shapes.sort(key=lambda s: s.get("area", 0), reverse=True) - images.sort(key=lambda s: s.get("area", 0), reverse=True) - - state.drawio_elements = shapes + images + texts - return state - - async def _render_xml_node(state: Paper2FigureState) -> Paper2FigureState: - elements = getattr(state, "drawio_elements", []) or [] - clean_bg_path = getattr(state, "clean_bg_path", "") or "" - has_bg = bool(clean_bg_path and os.path.exists(clean_bg_path)) - if not elements and not has_bg: - state.drawio_xml = "" - return state - - cells = [] - id_counter = 2 - page_width = 850 - page_height = 1100 - - if has_bg: - try: - with Image.open(clean_bg_path) as bg_img: - bg_w, bg_h = bg_img.size - page_width = bg_w - 
page_height = bg_h - data_uri = "data:image/png;base64," + _encode_image_base64(clean_bg_path) - style = _image_style(data_uri) - cells.append(_build_mxcell(str(id_counter), "", style, [0, 0, bg_w, bg_h])) - id_counter += 1 - except Exception as e: - log.warning(f"[image2drawio] embed background failed: {e}") - - for el in elements: - if el.get("kind") == "shape": - style = _shape_style(el.get("shape_type", "rect"), el.get("fill", "#ffffff"), el.get("stroke", "#000000")) - value = el.get("text", "") - cells.append(_build_mxcell(str(id_counter), value, style, el["bbox_px"])) - id_counter += 1 - elif el.get("kind") == "image": - img_path = el.get("image_path") - if not img_path or not os.path.exists(img_path): - continue - data_uri = "data:image/png;base64," + _encode_image_base64(img_path) - style = _image_style(data_uri) - cells.append(_build_mxcell(str(id_counter), "", style, el["bbox_px"])) - id_counter += 1 - elif el.get("kind") == "text": - style = _text_style(el.get("color", "#000000")) - value = el.get("text", "") - cells.append(_build_mxcell(str(id_counter), value, style, el["bbox_px"])) - id_counter += 1 - - xml_cells = "\n".join(cells) - full_xml = wrap_xml(xml_cells, page_width=page_width, page_height=page_height) - - base_dir = Path(_ensure_result_path(state)) - out_path = base_dir / "image2drawio.drawio" - out_path.write_text(full_xml, encoding="utf-8") - - state.drawio_xml = full_xml - state.drawio_output_path = str(out_path) - return state - - nodes = { - "_start_": _init_node, - "input": _input_node, - "ocr": _ocr_node, - "inpainting": _inpainting_node, - "sam": _sam_node, - "build_elements": _build_elements_node, - "render_xml": _render_xml_node, - "_end_": lambda s: s, - } - - edges = [ - ("input", "ocr"), - ("ocr", "inpainting"), - ("inpainting", "sam"), - ("sam", "build_elements"), - ("build_elements", "render_xml"), - ("render_xml", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "input") - return 
builder diff --git a/dataflow_agent/workflow/wf_image2ppt.py b/dataflow_agent/workflow/wf_image2ppt.py deleted file mode 100644 index ba60295..0000000 --- a/dataflow_agent/workflow/wf_image2ppt.py +++ /dev/null @@ -1,257 +0,0 @@ -""" -image2ppt workflow (VLM debug only) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -当前版本仅用于调试 VLM bbox: -1. Input: 单张图片 (FIGURE) -2. VLM: ImageTextBBoxAgent 提取文本 + bbox -3. Debug 输出: - - 在原图上画 bbox + 文本,保存 *_bbox_debug.png - - 把 bbox 区域刷成白色,保存 *_no_text.png -""" - -from __future__ import annotations -import os -import asyncio -from pathlib import Path -from typing import List, Dict, Any -import time -import copy - -import cv2 -import numpy as np -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root -from dataflow_agent.agentroles import create_vlm_agent - -log = get_logger(__name__) - -# --- Helpers --- - -def _ensure_result_path(state: Paper2FigureState) -> str: - raw = getattr(state, "result_path", None) - if raw: - return raw - root = get_project_root() - ts = int(__import__("time").time()) - base_dir = (root / "outputs" / "image2ppt" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - -# --- Workflow Definition --- - -@register("image2ppt") -def create_image2ppt_graph() -> GenericGraphBuilder: - """ - 简化版 image2ppt,仅做 VLM bbox 可视化 + 去字图生成,不生成 PPT。 - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - # 1. 初始化输出目录 - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - # 2. 
输入处理:FIGURE -> slide_images - async def input_processing_node(state: Paper2FigureState) -> Paper2FigureState: - req = getattr(state, "request", None) - if not req: - log.warning("[image2ppt][input] state.request is None") - return state - - if req.input_type == "FIGURE": - img_path = req.input_content - if img_path and os.path.exists(img_path): - state.slide_images = [img_path] - else: - log.error(f"[image2ppt][input] FIGURE image not found: {img_path}") - elif isinstance(req.input_content, list): - state.slide_images = [p for p in req.input_content if os.path.exists(p)] - else: - log.warning("[image2ppt][input] unsupported input_type, expect FIGURE or list") - - if not getattr(state, "slide_images", []): - log.warning("[image2ppt][input] No valid slide images found in input.") - else: - log.info(f"[image2ppt][input] slide_images = {state.slide_images}") - return state - - # 3. 只跑 VLM,写 state.vlm_pages - async def vlm_only_node(state: Paper2FigureState) -> Paper2FigureState: - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.warning("[image2ppt][vlm] no slide_images, skip") - state.vlm_pages = [] - return state - - async def _process_single_image(page_idx: int, img_path: str) -> Dict[str, Any]: - try: - temp_state = copy.copy(state) - temp_state.result_path = state.result_path - - agent = create_vlm_agent( - name="ImageTextBBoxAgent", - model_name="qwen-vl-ocr-2025-11-20", - chat_api_url=getattr(state.request, "chat_api_url", None), - vlm_mode="understanding", - additional_params={ - "input_image": img_path - } - ) - - new_state = await agent.execute(temp_state) - bbox_res = getattr(new_state, "bbox_result", []) - log.info(f"[image2ppt][VLM] page#{page_idx+1} extracted {len(bbox_res)} text items") - return { - "page_idx": page_idx, - "path": img_path, - "vlm_data": bbox_res, - } - except Exception as e: - log.error(f"[image2ppt][VLM] page#{page_idx+1} failed: {e}") - return { - "page_idx": page_idx, - "path": img_path, - 
"vlm_data": [], - "error": str(e), - } - - tasks = [_process_single_image(i, p) for i, p in enumerate(image_paths)] - results = await asyncio.gather(*tasks) - - # 关键:直接写主 state.vlm_pages - state.vlm_pages = results - log.info(f"[image2ppt][VLM] state.vlm_pages len = {len(state.vlm_pages)}") - return state - - # 4. 画框 + 去字图 - async def debug_draw_and_mask_node(state: Paper2FigureState) -> Paper2FigureState: - vlm_pages = getattr(state, "vlm_pages", []) or [] - if not vlm_pages: - log.warning("[image2ppt][debug] No VLM pages, skip debug draw.") - return state - - base_dir = Path(_ensure_result_path(state)) - debug_dir = base_dir / "vlm_debug" - debug_dir.mkdir(parents=True, exist_ok=True) - - for page in vlm_pages: - page_idx = page.get("page_idx", 0) - img_path = page.get("path") - items = page.get("vlm_data") or [] - - log.info(f"[image2ppt][debug] page#{page_idx+1} img_path={img_path}, items={len(items)}") - - if not img_path or not os.path.exists(img_path): - log.warning(f"[image2ppt][debug] image not found for page#{page_idx+1}: {img_path}") - continue - - try: - pil_img = Image.open(img_path).convert("RGB") - w, h = pil_img.size - img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) - except Exception as e: - log.error(f"[image2ppt][debug] open image failed: {e}") - continue - - debug_img = img_cv.copy() - mask_img = img_cv.copy() - - # 用户要求去掉 1024 max 处理,直接使用原图尺寸 - model_process_w, model_process_h = w, h - - for it in items: - # 兼容 qwen-vl-ocr 的 rotate_rect -> bbox - if "rotate_rect" in it and "bbox" not in it: - try: - rr = it["rotate_rect"] - if isinstance(rr, list) and len(rr) == 5: - cx, cy, rw, rh, angle = rr - # cv2.boxPoints 需要 ((cx, cy), (w, h), angle) - rect = ((float(cx), float(cy)), (float(rw), float(rh)), float(angle)) - box = cv2.boxPoints(rect) - - x_min = np.min(box[:, 0]) - x_max = np.max(box[:, 0]) - y_min = np.min(box[:, 1]) - y_max = np.max(box[:, 1]) - - # 坐标是归一化到 0-1000 的,所以除以 1000 得到 0-1 的 bbox - VLM_SCALE = 1000.0 - it["bbox"] = [ 
- max(0.0, min(1.0, y_min / VLM_SCALE)), - max(0.0, min(1.0, x_min / VLM_SCALE)), - max(0.0, min(1.0, y_max / VLM_SCALE)), - max(0.0, min(1.0, x_max / VLM_SCALE)) - ] - except Exception as e: - log.warning(f"[image2ppt][debug] rotate_rect convert failed: {e}") - - bn = it.get("bbox") - txt = it.get("text", "") - if not bn or len(bn) != 4: - continue - - # VLM bbox: [ymin, xmin, ymax, xmax] in 0-1 - y1_n, x1_n, y2_n, x2_n = bn - x1 = int(x1_n * w) - y1 = int(y1_n * h) - x2 = int(x2_n * w) - y2 = int(y2_n * h) - - x1 = max(0, min(w - 1, x1)) - x2 = max(0, min(w, x2)) - y1 = max(0, min(h - 1, y1)) - y2 = max(0, min(h, y2)) - if x2 <= x1 or y2 <= y1: - continue - - cv2.rectangle(debug_img, (x1, y1), (x2, y2), (255, 0, 0), 2) - label = (txt or "")[:10].replace("\n", " ") - cv2.putText( - debug_img, - label, - (x1, max(0, y1 - 5)), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - (0, 0, 255), - 1, - cv2.LINE_AA, - ) - - mask_img[y1:y2, x1:x2] = (255, 255, 255) - - debug_path = debug_dir / f"page_{page_idx+1:03d}_bbox_debug.png" - no_text_path = debug_dir / f"page_{page_idx+1:03d}_no_text.png" - - cv2.imwrite(str(debug_path), debug_img) - cv2.imwrite(str(no_text_path), mask_img) - - log.info(f"[image2ppt][debug] Saved bbox debug: {debug_path}") - log.info(f"[image2ppt][debug] Saved no-text image: {no_text_path}") - - return state - - nodes = { - "_start_": _init_result_path, - "input_processing": input_processing_node, - "vlm_only": vlm_only_node, - "debug_draw_and_mask": debug_draw_and_mask_node, - "_end_": lambda s: s, - } - - edges = [ - ("input_processing", "vlm_only"), - ("vlm_only", "debug_draw_and_mask"), - ("debug_draw_and_mask", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "input_processing") - return builder diff --git a/dataflow_agent/workflow/wf_paper2expfigure.py b/dataflow_agent/workflow/wf_paper2expfigure.py deleted file mode 100644 index 8558c06..0000000 --- a/dataflow_agent/workflow/wf_paper2expfigure.py +++ /dev/null @@ 
-1,1109 +0,0 @@ -""" -paper2expfigure workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -从 PDF 论文中提取表格并生成统计图的完整工作流 - -支持三种输入类型: -- PDF: 从 PDF 论文中提取表格 (完整流程) -- FIGURE: 直接输入表格图片 (跳过 PDF 解析和 MinerU) -- TEXT: 输入表格文本,先生成表格图片再处理 - -工作流程: -1. PDF → 图片 (pdf_to_images_node) -2. 图片 → MinerU 识别 (mineru_extract_node) -3. 提取表格数据 (table_extractor_node) -4. 提取论文核心思想 (paper_idea_extractor_node) - TEXT/FIGURE 模式跳过 -5. 智能推荐图表类型和生成代码 (code_executor_node) - - 调用 chart_type_recommender Agent 推荐图表类型 - - 调用 chart_code_generator Agent 生成 matplotlib 代码 - - 执行代码生成图表 -""" - -from __future__ import annotations -import os -import uuid -import json -from pathlib import Path -from typing import Dict, Any, List -import asyncio -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from functools import reduce -from pptx import Presentation -from pptx.util import Inches -from PIL import Image - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_simple_agent -from dataflow_agent.agentroles.paper2any_agents.chart_type_recommender import create_chart_type_recommender -from dataflow_agent.agentroles.paper2any_agents.chart_code_generator import create_chart_code_generator -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import ( - pdf_to_pil_images, - extract_tables_from_mineru_results, - extract_text_from_mineru_results, - execute_matplotlib_code, -) -from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_aio_two_step_extract -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async - - -log = get_logger(__name__) - - -@register("paper2expfigure") -def create_paper2expfigure_graph() -> GenericGraphBuilder: - """ - 
Paper2ExpFigure Workflow: 从 PDF 提取表格并生成统计图 - - 支持三种输入模式: - - PDF: state.paper_file (完整流程) - - FIGURE: state.fig_draft_path (表格图片,跳过 PDF 解析) - - TEXT: state.paper_idea (表格文本,先生成图片) - - 命令: dfa run --wf paper2expfigure - """ - builder = GenericGraphBuilder( - state_model=Paper2FigureState, - entry_point="_start_" - ) - - # ====================================================================== - # PRE-TOOLS: 为 Agent 提供输入数据 - # ====================================================================== - - @builder.pre_tool("paper_content", "paper_idea_extractor") - def _get_paper_content(state: Paper2FigureState) -> str: - """ - 从 MinerU 结果或 PDF 中提取文本内容,供 paper_idea_extractor 使用 - """ - # 优先从 MinerU 结果中提取 - log.critical("正在从 MinerU 结果中提取文本内容") - if hasattr(state, 'temp_data') and 'mineru_items' in state.temp_data: - mineru_items = state.temp_data.get('mineru_items', []) - if mineru_items: - text = extract_text_from_mineru_results(mineru_items, max_chars=15000) - if text: - return f"Paper content extracted from PDF:\n\n{text}" - - # 如果没有 MinerU 结果,直接从 PDF 读取(回退方案) - import fitz - pdf_path = state.paper_file - if not pdf_path or not os.path.exists(pdf_path): - log.warning("paper_file 为空或不存在,无法读取 PDF 内容") - return "" - - try: - doc = fitz.open(pdf_path) - text_parts = [] - # 读取前 10 页 - for page_idx in range(min(10, len(doc))): - page = doc.load_page(page_idx) - text_parts.append(page.get_text("text") or "") - doc.close() - - content = "\n".join(text_parts).strip() - log.info(f"[pre_tool] 从 PDF 直接提取了 {len(content)} 字符") - return f"Paper content from PDF:\n\n{content[:15000]}" - except Exception as e: - log.error(f"读取 PDF 失败: {e}") - return "" - - # ============================================================== - # 路由函数:根据输入类型决定流程 - # ============================================================== - - def _route_by_input_type(state: Paper2FigureState) -> str: - """ - 根据 input_type 决定下一个节点 - """ - input_type = getattr(state, 'input_type', None) or getattr(state.request, 
'input_type', 'PDF') - input_type = str(input_type).upper() - - log.info(f"[_route_by_input_type] input_type = {input_type}") - - if input_type == "FIGURE": - return "figure_input_node" - elif input_type == "TEXT": - return "text_to_table_image_node" - else: # 默认 PDF - return "pdf_to_images_node" - - # ============================================================== - # NODES: 工作流节点 - # ============================================================== - - def _start_(state: Paper2FigureState) -> Paper2FigureState: - """起始节点:初始化""" - # 确保 temp_data 存在 - if not hasattr(state, 'temp_data') or state.temp_data is None: - state.temp_data = {} - - # 确保 result_path 存在且为绝对路径 - if not state.result_path: - output_dir = f"./outputs/paper2expfigure_{uuid.uuid4().hex[:8]}" - output_path = Path(output_dir) - output_path.mkdir(parents=True, exist_ok=True) - state.result_path = str(output_path.absolute()) - else: - # 转换为绝对路径 - state.result_path = str(Path(state.result_path).resolve()) - Path(state.result_path).mkdir(parents=True, exist_ok=True) - - # 从 request 同步 input_type 到 state - if hasattr(state.request, 'input_type'): - state.input_type = state.request.input_type - - log.info(f"[_start_] result_path = {state.result_path}") - log.info(f"[_start_] input_type = {getattr(state, 'input_type', 'PDF')}") - - return state - - async def figure_input_node(state: Paper2FigureState) -> Paper2FigureState: - """ - FIGURE 模式入口节点:处理直接输入的表格图片 - - 输入:state.fig_draft_path (单个图片路径或逗号分隔的多个路径) - 输出:构造与 MinerU 兼容的数据结构,直接进入 table_extractor_node - """ - log.info("[figure_input_node] 开始处理表格图片输入...") - - fig_path = state.fig_draft_path or "" - if not fig_path: - log.error("[figure_input_node] fig_draft_path 为空") - return state - - # 支持多个图片路径(逗号分隔) - image_paths = [p.strip() for p in fig_path.split(",") if p.strip()] - - output_path = Path(state.result_path) - table_images_dir = output_path / "table_images" - table_images_dir.mkdir(parents=True, exist_ok=True) - - # 构造表格数据结构(跳过 MinerU,直接构造 extracted_tables) 
- tables = [] - valid_image_paths = [] - - for idx, img_path in enumerate(image_paths): - img_path = Path(img_path) - if not img_path.exists(): - log.warning(f"[figure_input_node] 图片不存在: {img_path}") - continue - - # 复制图片到输出目录 - table_id = f"table_{idx}" - dest_path = table_images_dir / f"{table_id}.png" - - try: - img = Image.open(img_path) - img.save(dest_path) - valid_image_paths.append(str(dest_path)) - - # 构造表格信息(没有 MinerU 解析,headers/rows 为空) - tables.append({ - "table_id": table_id, - "headers": [], - "rows": [], - "caption": f"Table from image: {img_path.name}", - "bbox": [0, 0, 1, 1], - "content": "", - "image_path": str(dest_path), - "page_index": 0, - "page_number": 1, - }) - - log.info(f"[figure_input_node] 处理图片 {idx + 1}: {img_path} -> {dest_path}") - - except Exception as e: - log.error(f"[figure_input_node] 处理图片失败 ({img_path}): {e}") - - state.temp_data['image_paths'] = valid_image_paths - state.extracted_tables = tables - - # FIGURE 模式不需要提取 paper_idea,设置默认值 - if not state.paper_idea: - state.paper_idea = "Direct table image input - no paper context available" - - log.info(f"[figure_input_node] 完成,共处理 {len(tables)} 个表格图片") - return state - - async def text_to_table_image_node(state: Paper2FigureState) -> Paper2FigureState: - """ - TEXT 模式入口节点:将表格文本转换为表格图片 - - 输入:state.paper_idea (表格文本,支持 CSV/Markdown/纯文本/LaTeX 等格式) - 输出:通过 LLM 生成 matplotlib 代码渲染表格图片,支持多级表头等复杂结构 - - 支持多表格:自动识别并分割文本中的多个表格,按 table_0, table_1... 
命名 - """ - from dataflow_agent.agentroles.paper2any_agents.table_text_renderer import ( - render_table_from_text, - split_tables_from_text, - ) - - log.info("[text_to_table_image_node] 开始处理表格文本输入...") - - table_text = state.paper_idea or "" - if not table_text: - log.error("[text_to_table_image_node] paper_idea (表格文本) 为空") - return state - - output_path = Path(state.result_path).resolve() - table_images_dir = output_path / "table_images" - table_images_dir.mkdir(parents=True, exist_ok=True) - - # 先分割多个表格 - log.info("[text_to_table_image_node] 分析文本中的表格...") - table_segments = await split_tables_from_text( - text=table_text, - state=state, - model_name=state.request.model or "deepseek-v3.2", - ) - - log.info(f"[text_to_table_image_node] 识别到 {len(table_segments)} 个表格") - - tables = [] - valid_image_paths = [] - - # 循环处理每个表格 - for idx, segment in enumerate(table_segments): - table_id = f"table_{idx}" - img_path = (table_images_dir / f"{table_id}.png").resolve() - segment_text = segment.get("text", "") - caption = segment.get("caption", "") - - if not segment_text.strip(): - log.warning(f"[text_to_table_image_node] 表格 {table_id} 文本为空,跳过") - continue - - log.info(f"[text_to_table_image_node] 处理表格 {idx + 1}/{len(table_segments)}: {table_id}") - - try: - # 使用 table_text_renderer agent 渲染表格 - success, parsed_data = await render_table_from_text( - table_text=segment_text, - output_path=img_path, - state=state, - title=caption, - model_name=state.request.model or "deepseek-v3.2", - ) - - if success: - valid_image_paths.append(str(img_path)) - - tables.append({ - "table_id": table_id, - "headers": parsed_data.get("headers", []), - "rows": parsed_data.get("rows", []), - "caption": caption, - "bbox": [0, 0, 1, 1], - "content": segment_text, - "image_path": str(img_path), - "page_index": 0, - "page_number": 1, - "has_multi_level_header": parsed_data.get("has_multi_level_header", False), - }) - - log.info(f"[text_to_table_image_node] 生成表格图片: {img_path}") - else: - 
log.warning(f"[text_to_table_image_node] 表格 {table_id} 渲染失败") - - except Exception as e: - log.error(f"[text_to_table_image_node] 生成表格图片失败: {e}") - import traceback - traceback.print_exc() - - state.temp_data['image_paths'] = valid_image_paths - state.extracted_tables = tables - - # TEXT 模式保留原始文本作为 paper_idea 的补充 - state.paper_idea = f"Table data from text input:\n{table_text}" - - log.info(f"[text_to_table_image_node] 完成,共生成 {len(tables)} 个表格图片") - return state - - async def pdf_to_images_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 1: PDF → 图片 - 将 PDF 的每一页转换为 PIL Image 对象,保存到临时目录 - """ - pdf_path = Path(state.paper_file) - if not pdf_path.exists(): - log.error(f"PDF 文件不存在: {pdf_path}") - return state - - log.info(f"[pdf_to_images] 开始转换 PDF: {pdf_path}") - - # 转换 PDF 为图片 - images = pdf_to_pil_images(pdf_path, dpi=150) - - # 创建临时目录保存图片 - output_path = Path(state.result_path) - images_dir = output_path / "images" - images_dir.mkdir(exist_ok=True) - - # 保存图片 - image_paths = [] - for idx, img in enumerate(images): - img_path = images_dir / f"page_{idx+1}.png" - img.save(img_path) - image_paths.append(str(img_path)) - log.info(f"[pdf_to_images] 保存第 {idx+1} 页: {img_path}") - - # 存储到 state(使用绝对路径) - state.temp_data['image_paths'] = image_paths - - log.info(f"[pdf_to_images] 完成,共转换 {len(images)} 页") - return state - - async def mineru_extract_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 2: MinerU 识别 - 使用 MinerU HTTP API 识别图片中的文本和表格 - """ - image_paths = state.temp_data.get('image_paths', []) - if not image_paths: - log.warning("[mineru_extract] 没有图片路径,跳过") - return state - - output_path = Path(state.result_path) - mineru_dir = output_path / "mineru_results" - mineru_dir.mkdir(exist_ok=True) - - port = state.mineru_port - all_items = [] - - # 对每一页图片执行 MinerU 识别 - for idx, img_path in enumerate(image_paths, 1): - log.info(f"[mineru_extract] 处理图片 {idx}/{len(image_paths)}: {img_path}") - try: - # 使用 run_aio_two_step_extract 进行识别 - items = 
await run_aio_two_step_extract( - image_path=str(img_path), - port=port, - ) - - # items 是一个列表,直接扩展到 all_items - if isinstance(items, list): - # 为每个 item 添加页面信息 - for item in items: - item['page_index'] = idx - 1 # 0-based index - item['page_number'] = idx # 1-based number - - all_items.extend(items) - log.info(f"[mineru_extract] 从 page_{idx} 提取了 {len(items)} 个元素") - - # 保存每页的识别结果为 JSON 文件(便于调试) - result_file = mineru_dir / f"page_{idx}_result.json" - with open(result_file, 'w', encoding='utf-8') as f: - json.dump(items, f, ensure_ascii=False, indent=2) - log.debug(f"[mineru_extract] 保存结果到: {result_file}") - else: - log.warning(f"[mineru_extract] 返回结果不是列表: {type(items)}") - - except Exception as e: - log.error(f"[mineru_extract] MinerU 识别失败 (page_{idx}): {e}") - - # 存储结果 - state.temp_data['mineru_items'] = all_items - - # 保存所有结果的汇总文件 - if all_items: - summary_file = mineru_dir / "all_results.json" - with open(summary_file, 'w', encoding='utf-8') as f: - json.dump(all_items, f, ensure_ascii=False, indent=2) - log.info(f"[mineru_extract] 保存汇总结果到: {summary_file}") - - log.info(f"[mineru_extract] 完成,共提取 {len(all_items)} 个元素") - return state - - async def table_extractor_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 3: 提取表格 - 从 MinerU 识别结果中提取表格数据,并保存表格区域图片 - - 注意:FIGURE/TEXT 模式会跳过此节点(已在入口节点处理) - """ - # 如果已经有 extracted_tables(FIGURE/TEXT 模式),跳过 - if state.extracted_tables: - log.info(f"[table_extractor] 已有 {len(state.extracted_tables)} 个表格,跳过提取") - return state - - mineru_items = state.temp_data.get('mineru_items', []) - if not mineru_items: - log.warning("[table_extractor] 没有 MinerU 结果,跳过") - return state - - log.info("[table_extractor] 开始提取表格...") - tables = extract_tables_from_mineru_results(mineru_items, min_rows=2, min_cols=2) - - log.info(f"[table_extractor] 提取了 {len(tables)} 个表格") - - # 打印表格摘要 - for table in tables: - log.info(f" - {table['table_id']}: {len(table['headers'])} 列 x {len(table['rows'])} 行") - - # 保存表格区域图片 - if tables: - output_path 
= Path(state.result_path) - table_images_dir = output_path / "table_images" - table_images_dir.mkdir(exist_ok=True) - - # 获取原始图片路径 - image_paths = state.temp_data.get('image_paths', []) - - # 为每个表格裁剪并保存图片 - saved_count = 0 - for item in mineru_items: - if item.get('type') != 'table': - continue - - bbox = item.get('bbox', []) - if len(bbox) != 4: - continue - - # 找到对应的表格(通过 bbox 匹配) - table_match = None - for table in tables: - if table.get('bbox') == bbox: - table_match = table - break - - if not table_match: - continue - - # 从 item 中直接获取页面索引(在 mineru_extract_node 中添加的) - page_idx = item.get('page_index') - - if page_idx is not None and page_idx < len(image_paths): - try: - # 读取原始图片 - img_path = image_paths[page_idx] - img = Image.open(img_path) - - # bbox 是归一化坐标 [x0, y0, x1, y1],范围 0-1 - img_width, img_height = img.size - x0 = int(bbox[0] * img_width) - y0 = int(bbox[1] * img_height) - x1 = int(bbox[2] * img_width) - y1 = int(bbox[3] * img_height) - - # 裁剪表格区域 - table_img = img.crop((x0, y0, x1, y1)) - - # 保存图片 - table_id = table_match['table_id'] - page_num = item.get('page_number', page_idx + 1) - table_img_path = table_images_dir / f"{table_id}_page{page_num}.png" - table_img.save(table_img_path) - - # 将图片路径添加到 table 信息中 - table_match['image_path'] = str(table_img_path) - table_match['page_index'] = page_idx - table_match['page_number'] = page_num - - saved_count += 1 - log.info(f"[table_extractor] 保存表格图片: {table_img_path}") - - except Exception as e: - log.error(f"[table_extractor] 裁剪表格图片失败 ({table_match.get('table_id', 'unknown')}): {e}") - import traceback - traceback.print_exc() - - log.info(f"[table_extractor] 共保存了 {saved_count} 个表格图片到: {table_images_dir}") - - state.extracted_tables = tables - - return state - - async def paper_idea_extractor(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 4: 提取论文核心思想 - 调用 paper_idea_extractor Agent 从论文中提取核心思想 - - 注意:FIGURE/TEXT 模式会跳过此节点 - """ - # 检查是否需要跳过(FIGURE/TEXT 模式已设置 paper_idea) - input_type = 
getattr(state, 'input_type', None) or getattr(state.request, 'input_type', 'PDF') - if input_type in ['FIGURE', 'TEXT']: - log.info(f"[paper_idea_extractor] {input_type} 模式,跳过论文思想提取") - return state - - log.info("[paper_idea_extractor] 开始提取论文核心思想...") - - agent = create_simple_agent( - name="paper_idea_extractor", - model_name=getattr(state.request, "chart_model", "deepseek-v3.2"), - temperature=0.1, - max_tokens=4096, - parser_type="json", - ) - - state = await agent.execute(state=state) - - paper_idea = state.paper_idea or "" - log.info(f"[paper_idea_extractor] 提取的核心思想长度: {len(paper_idea)} 字符") - log.info(f"[paper_idea_extractor] 核心思想预览: {paper_idea[:200]}...") - - return state - - async def chart_type_recommender_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 5: 智能推荐图表类型 - 调用 chart_type_recommender Agent 智能推荐图表类型 - """ - log.info("[chart_type_recommender] 开始智能推荐图表类型...") - - tables = state.extracted_tables - if not tables: - log.warning("[chart_type_recommender] 没有表格数据,跳过") - return state - - image_paths = [t["image_path"] for t in tables if "image_path" in t] - - # 这里直接用asyncio原生实现了并行,后面可以考虑改成更加符合langgraph的实现,使用Send API - - @dataclass - class ChartTypeRecommenderState: - request: Any - pre_tool_results: Dict[str, Any] - table: Dict - chart_type_recommender: Any = None - agent_results: Dict = field(default_factory=dict) - chart_configs: Dict = field(default_factory=dict) # 这里存储最终结果:{table_id: chart_config} - - async def task(ctr_state: ChartTypeRecommenderState): - table = ctr_state.table - try: - if "image_path" not in table: - log.error(f"[chart_type_recommender] 表格缺少图片路径: {table}") - return {table["table_id"]: None} - input_image = table["image_path"] - - vlm_config = { - "mode": "understanding", - "input_image": input_image - } - - agent = create_chart_type_recommender( - tool_manager=get_tool_manager(), - model_name=getattr(state.request, "chart_model", "deepseek-v3.2"), - temperature=0.1, - max_tokens=2048, - vlm_config=vlm_config - ) - - 
ctr_state = await agent.execute(state=ctr_state) - - result = ctr_state.chart_configs - - return result - except Exception as e: - log.error(f"[chart_type_recommender] 处理表格出错: {e}") - return {table["table_id"]: None} - - - # 手动为每个并行节点注入 pre_tool_results - - states = [ - ChartTypeRecommenderState( - request=state.request, - table=table, - pre_tool_results={ - "table_info": {"table_id": table["table_id"]}, - "paper_idea": state.paper_idea - }, - ) - for table in tables - ] - - print(f"states: {states}") - - tasks = [task(s) for s in states] - results = await asyncio.gather(*tasks) - # 过滤掉失败的节点的返回值 - results = [ - result - for result in results if result is not None - for key, value in result.items() if value is not None - ] - results = reduce(lambda x, y: {**x, **y}, results, {}) # 合并结果为一个大字典 - state.chart_configs = results - - return state - - - async def code_executor_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 6: 执行代码生成图表 - 调用 chart_code_generator Agent 智能生成图表 - """ - log.info("[code_executor] 开始执行代码生成图表...") - - tables = state.extracted_tables - if not tables: - log.warning("[code_executor] 没有表格数据,跳过") - return state - - log.info(f"[code_executor] 共有 {len(tables)} 个表格待处理") - - image_paths = [t["image_path"] for t in tables if "image_path" in t] - - output_path = Path(state.result_path).resolve() - charts_dir = output_path / "charts" - charts_dir.mkdir(exist_ok=True) - - # 创建中间结果目录 - intermediate_dir = output_path / "chart_intermediate" - intermediate_dir.mkdir(exist_ok=True) - - generated_charts = [] - - # 获取论文核心思想 - paper_idea = state.paper_idea or "No paper idea extracted" - - @dataclass - class ChartCodeGeneratorState: - request: Any - table: Dict - result_path: str - pre_tool_results: Dict[str, Any] = field(default_factory=dict) - chart_code_generator: Any = None - agent_results: Dict = field(default_factory=dict) - generated_codes: Dict[str, Dict[str, Any]] = field(default_factory=dict) # 生成的代码列表 - - async def task(ccg_state: 
ChartCodeGeneratorState): - table = ccg_state.table - table_id = table['table_id'] - caption = table.get('caption', '') - chart_config = ccg_state.pre_tool_results.get("chart_config", {}) - - log.info(f"[code_executor] 处理表格: {table_id}") - - try: - input_image = table["image_path"] - - vlm_config = { - "mode": "understanding", - "input_image": input_image - } - - # 调用 Agent - chart_code_agent = create_chart_code_generator( - tool_manager=get_tool_manager(), - model_name=getattr(state.request, "chart_model", "deepseek-v3.2"), - temperature=0.0, - max_tokens=4096, - vlm_config=vlm_config, - ) - - ccg_state = await chart_code_agent.execute(state=ccg_state) - - # 获取生成的代码 - if ccg_state.generated_codes: - code_entry = ccg_state.generated_codes[table_id] # 获取刚生成的代码 - code = code_entry.get('code', '') - description = code_entry.get('description', '') - log.info(f"[code_executor] 生成代码长度: {len(code)} 字符") - log.info(f"[code_executor] 代码描述: {description}") - else: - log.error(f"[code_executor] chart_code_generator 未返回代码") - raise Exception(f"chart_code_generator 未返回代码") - - # 4. 保存中间结果 - intermediate_file = intermediate_dir / f"{table_id}_intermediate.json" - intermediate_data = { - "table_id": table_id, - "timestamp": str(Path(ccg_state.result_path).name), - - # 表格数据 - "table_data": { - "caption": caption, - }, - - # Agent 推荐结果 - "chart_config": chart_config, - - # 生成的代码 - "generated_code": code, - "code_description": description, - } - - with open(intermediate_file, 'w', encoding='utf-8') as f: - json.dump(intermediate_data, f, ensure_ascii=False, indent=2) - log.info(f"[code_executor] 保存中间结果: {intermediate_file}") - - # 保存代码文件(便于查看和调试) - code_file = intermediate_dir / f"{table_id}_code.py" - with open(code_file, 'w', encoding='utf-8') as f: - f.write(code) - log.debug(f"[code_executor] 保存代码文件: {code_file}") - - # 5. 
执行代码生成图表 - # 需要将 output_path, headers, rows 注入到代码中 - chart_path = (charts_dir / f"{table_id}.png").resolve() - - # 构建完整的可执行代码 - exec_code = f""" -# Auto-generated code execution wrapper -output_path = {repr(str(chart_path))} - -# Generated chart code -{code} -""" - - result = execute_matplotlib_code( - code=exec_code, - output_path=chart_path, - timeout=30, - ) - - if result['success']: - generated_charts.append(str(chart_path)) - log.info(f"[code_executor] 生成图表: {chart_path}") - - # 更新中间结果,添加执行状态 - intermediate_data["execution_result"] = { - "success": True, - "chart_path": str(chart_path), - "error": None - } - else: - log.error(f"[code_executor] 生成图表失败 ({table_id}): {result['error']}") - - # 更新中间结果,添加错误信息 - intermediate_data["execution_result"] = { - "success": False, - "chart_path": None, - "error": result['error'] - } - - # 重新保存中间结果(包含执行结果) - with open(intermediate_file, 'w', encoding='utf-8') as f: - json.dump(intermediate_data, f, ensure_ascii=False, indent=2) - - code = ccg_state.generated_codes if ccg_state.generated_codes else {} - chart_path = {table_id: chart_path} - return (code, chart_path) - - except Exception as e: - log.error(f"[code_executor] 处理表格 {table_id} 时出错: {e}") - import traceback - log.error(f"[code_executor] 错误堆栈:\n{traceback.format_exc()}") - return ({table_id: None}, {table_id: None}) - - # 过滤掉不适合生成图表的表格,定义匿名函数封装复杂提取逻辑,提高代码可读性 - get_chart_config = lambda x: state.chart_configs.get(x.get("table_id"), {}) - is_suitable = lambda x: get_chart_config(x).get("is_suitable_for_chart", True) - - # Debug: Print paper_idea length instead of full content to avoid log truncation - log.info(f"[code_executor] paper_idea length: {len(state.paper_idea) if state.paper_idea else 0} characters") - - states = [ - ChartCodeGeneratorState( - request=state.request, - table=table, - result_path=state.result_path, - pre_tool_results={ - "paper_idea": state.paper_idea, - "chart_config": get_chart_config(table), - "table_caption": table.get("caption", ""), - } - ) 
- for table in tables if is_suitable(table) - ] - - tasks = [task(s) for s in states] - generated_results = await asyncio.gather(*tasks) - generated_code = [result[0] for result in generated_results] - generated_charts = [result[1] for result in generated_results] - - # 过滤掉失败节点的值 - 修复:正确过滤 None 值 - code_results = {} - for code_dict in generated_code: - if code_dict: - for table_id, code_entry in code_dict.items(): - if code_entry is not None: - code_results[table_id] = code_entry - - chart_results = {} - for chart_dict in generated_charts: - if chart_dict: - for table_id, chart_path in chart_dict.items(): - if chart_path is not None: - chart_results[table_id] = chart_path - - state.generated_code = code_results - state.generated_charts = chart_results - - log.info(f"[code_executor] 完成,共生成 {len(generated_charts)} 个图表") - log.info(f"[code_executor] 中间结果保存在: {intermediate_dir}") - - return state - - async def post_stylize_node(state: Paper2FigureState) -> Paper2FigureState: - """调用Nano Banana模型对生成的图表进行风格化,更美观""" - log.info(f"[post_stylize] 开始") - # 获取生成的图表路径列表,用于分发任务 - chart_paths = state.generated_charts - - if not chart_paths: - log.warning("[post_stylize] 没有图表需要风格化,跳过") - return state - - save_dir = Path(state.result_path) / Path("stylized_charts") - save_dir.mkdir(parents=True, exist_ok=True) - - stylize_prompt = f"把这张统计图放大字体,并进行以 {state.request.style} 为主题的风格化,让它更美观,但需要保障数据的准确性,避免恶性的溢出和重叠" - - async def stylize_task(table_id: str, save_dir_path: str, chart_path: str): - save_dir_p = Path(save_dir_path) - chart_path_p = Path(chart_path) - save_path = save_dir_p / chart_path_p.name.replace(".png", "_stylized.png") - - log.info(f"[post_stylize] 正在风格化图表: {table_id}") - - log.critical(f"image_path: {chart_path_p}") - - try: - b64_result = await generate_or_edit_and_save_image_async( - prompt=stylize_prompt, - save_path=str(save_path), - api_url=state.request.chat_api_url, - api_key=state.request.api_key, - model=state.request.gen_fig_model, - 
image_path=str(chart_path_p), - use_edit=True - ) - - log.info(f"[post_stylize] 图表风格化完成: {table_id}") - return {table_id: [b64_result, save_path]} - - except Exception as e: - log.error(f"[post_stylize] {table_id} 图表风格化出错: {e}") - return {table_id: None} - - tasks = [stylize_task(table_id, str(save_dir), str(chart_path)) for table_id, chart_path in chart_paths.items()] - results = await asyncio.gather(*tasks) - # 过滤掉失败的图表 - results = [ - result - for result in results if result is not None - for table_id, stylize_result in result.items() if stylize_result is not None - ] - - state.stylize_results = reduce(lambda x, y: {**x, **y}, results, {}) - - log.info(f"[post_stylize] 完成") - return state - - async def assemble_to_ppt(state: Paper2FigureState) -> Paper2FigureState: - log.info(f"[assemble_to_ppt] 开始") - - try: - from pptx.util import Pt - from pptx.enum.text import PP_ALIGN - from pptx.dml.color import RGBColor - - # 获取生成的图表路径 - chart_paths = state.generated_charts - if not chart_paths: - log.warning("[assemble_to_ppt] 没有生成的图表,跳过") - return state - - # 获取风格化图片路径 - stylized_charts = state.stylize_results - - # 创建 PPT - prs = Presentation() - prs.slide_width = Inches(10) - prs.slide_height = Inches(7.5) - - def add_title_slide(title_text, subtitle_text=""): - """添加标题页""" - title_slide_layout = prs.slide_layouts[0] # 标题布局 - slide = prs.slides.add_slide(title_slide_layout) - - # 设置标题 - title = slide.shapes.title - title.text = title_text - title.text_frame.paragraphs[0].font.size = Pt(44) - title.text_frame.paragraphs[0].font.bold = True - title.text_frame.paragraphs[0].alignment = PP_ALIGN.CENTER - - # 设置副标题 - if subtitle_text and len(slide.placeholders) > 1: - subtitle = slide.placeholders[1] - subtitle.text = subtitle_text - subtitle.text_frame.paragraphs[0].font.size = Pt(24) - subtitle.text_frame.paragraphs[0].alignment = PP_ALIGN.CENTER - - return slide - - def add_image_slide(image_path, title_text=""): - """添加图片页""" - blank_slide_layout = prs.slide_layouts[6] 
# 空白布局 - slide = prs.slides.add_slide(blank_slide_layout) - - # 添加标题(如果有) - if title_text: - title_box = slide.shapes.add_textbox( - Inches(0.5), Inches(0.2), Inches(9), Inches(0.8) - ) - title_frame = title_box.text_frame - title_frame.text = title_text - title_frame.paragraphs[0].font.size = Pt(24) - title_frame.paragraphs[0].font.bold = True - title_frame.paragraphs[0].alignment = PP_ALIGN.CENTER - - # 获取图片尺寸 - img = Image.open(image_path) - img_width, img_height = img.size - - # 计算适合幻灯片的尺寸(保持宽高比) - slide_width = prs.slide_width - slide_height = prs.slide_height - - # 为标题留出空间 - available_height = slide_height - Inches(1.2) if title_text else slide_height - Inches(0.5) - max_width = slide_width - Inches(1) - max_height = available_height - - # 计算缩放比例 - width_ratio = max_width / img_width - height_ratio = max_height / img_height - scale = min(width_ratio, height_ratio) - - # 计算最终尺寸 - final_width = int(img_width * scale) - final_height = int(img_height * scale) - - # 居中放置 - left = (slide_width - final_width) / 2 - top_offset = Inches(1.2) if title_text else Inches(0.5) - top = top_offset + (available_height - final_height) / 2 - - # 添加图片 - slide.shapes.add_picture( - str(image_path), - left, - top, - width=final_width, - height=final_height - ) - - return slide - - # 1. 添加总标题页 - add_title_slide( - "论文图表生成结果", - "Paper2ExpFigure Workflow Results" - ) - - # 2. 添加原始图表部分标题页 - add_title_slide( - "原始实验图表", - "Original Experimental Charts" - ) - - # 3. 添加原始图表 - for table_id, chart_path in chart_paths.items(): - if not os.path.exists(chart_path): - log.warning(f"[assemble_to_ppt] 图表文件不存在: {chart_path}") - continue - - add_image_slide(chart_path, f"图表 {table_id}") - log.info(f"[assemble_to_ppt] 添加原始图表到 PPT: {table_id}") - - # 4. 
如果有风格化图片,添加风格化部分 - if stylized_charts: - # 添加风格化图表部分标题页 - add_title_slide( - "风格化图表", - "Stylized Charts (Vintage Print Style)" - ) - - # 添加风格化图表 - for table_id, stylized_path in stylized_charts.items(): - stylized_path = stylized_path[1] - if not os.path.exists(stylized_path): - log.warning(f"[assemble_to_ppt] 风格化图表文件不存在: {stylized_path}") - continue - - add_image_slide(stylized_path, f"风格化图表 {table_id}") - log.info(f"[assemble_to_ppt] 添加风格化图表到 PPT: {table_id}") - - # 保存 PPT - output_path = Path(state.result_path) - ppt_path = output_path / "generated_charts.pptx" - prs.save(str(ppt_path)) - - total_slides = len(prs.slides) - log.info(f"[assemble_to_ppt] PPT 已保存: {ppt_path}") - log.info(f"[assemble_to_ppt] 共创建 {total_slides} 张幻灯片") - log.info(f"[assemble_to_ppt] 包含 {len(chart_paths)} 个原始图表和 {len(stylized_charts)} 个风格化图表") - - state.ppt_path = str(ppt_path) - - except ImportError: - log.error("[assemble_to_ppt] 缺少 python-pptx 库,请安装: pip install python-pptx") - except Exception as e: - log.error(f"[assemble_to_ppt] 生成 PPT 失败: {e}") - import traceback - traceback.print_exc() - - log.info(f"[assemble_to_ppt] 完成") - return state - - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - - nodes = { - "_start_": _start_, - "pdf_to_images_node": pdf_to_images_node, - "mineru_extract_node": mineru_extract_node, - "figure_input_node": figure_input_node, - "text_to_table_image_node": text_to_table_image_node, - "table_extractor_node": table_extractor_node, - "paper_idea_extractor": paper_idea_extractor, - "chart_type_recommender_node": chart_type_recommender_node, - "code_executor_node": code_executor_node, - "post_stylize_node": post_stylize_node, - "assemble_to_ppt": assemble_to_ppt, - "_end_": lambda state: state, - } - - # 边定义 - edges = [ - # PDF 流程 - ("pdf_to_images_node", "mineru_extract_node"), - ("mineru_extract_node", "table_extractor_node"), - - # FIGURE 流程 - 直接到 
chart_type_recommender - ("figure_input_node", "chart_type_recommender_node"), - - # TEXT 流程 - 直接到 chart_type_recommender - ("text_to_table_image_node", "chart_type_recommender_node"), - - # PDF 流程继续 - ("table_extractor_node", "paper_idea_extractor"), - ("paper_idea_extractor", "chart_type_recommender_node"), - - # 公共流程 - ("chart_type_recommender_node", "code_executor_node"), - ("code_executor_node", "post_stylize_node"), - ("post_stylize_node", "assemble_to_ppt"), - - # 最终节点 - ("assemble_to_ppt", "_end_"), - ] - - # 添加条件路由 - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", _route_by_input_type) - - return builder diff --git a/dataflow_agent/workflow/wf_paper2figure_image_only.py b/dataflow_agent/workflow/wf_paper2figure_image_only.py deleted file mode 100644 index fe12062..0000000 --- a/dataflow_agent/workflow/wf_paper2figure_image_only.py +++ /dev/null @@ -1,228 +0,0 @@ -""" -icongen workflow (Image Only) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Only generates the figure (and layout template), stops before SAM/PPT generation. -Used for preview/confirmation step. - -1. Idea Extraction (PDF/Text) -2. Prompt Generation -3. 
Image Generation (Content + Layout Template) -""" - -from __future__ import annotations -import asyncio -import json -import os -import time -from pathlib import Path - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_graph_agent, create_react_agent -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root - -import fitz -import PyPDF2 - -log = get_logger(__name__) - -def _ensure_result_path(state: Paper2FigureState) -> str: - raw = getattr(state, "result_path", None) - if raw: - return raw - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "paper2figure" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - -@register("paper2fig_image_only") -def create_p2fig_image_only_graph() -> GenericGraphBuilder: - builder = GenericGraphBuilder(state_model=Paper2FigureState, - entry_point="_start_") - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - @builder.pre_tool("paper_content", "paper_idea_extractor") - def _get_abstract_intro(state: Paper2FigureState): - try: - with open(state.paper_file, 'rb') as f: - reader = PyPDF2.PdfReader(f) - paper_title = reader.metadata.get('/Title', 'Unknown Title') - except Exception: - paper_title = "Unknown Title" - - file_path = state.paper_file - pdf_document = fitz.open(file_path) - - text = "" - for page_num in range(min(10, len(pdf_document))): - page = pdf_document.load_page(page_num) - text += page.get_text("text") - - content = text.strip() - final_text = ( - f"The title of the 
paper is {paper_title}\n\n" - f"Here's first ten page content: {content}" - ) - log.info(f"{final_text}") - return final_text - - @builder.pre_tool("paper_idea", "figure_desc_generator") - def _get_paper_idea(state: Paper2FigureState): - # 根据请求语言添加指令 - lang = getattr(state.request, "language", "zh") - log.critical(f'[image_only]: lang {lang}') - lang_instruction = "" - if lang == "zh": - lang_instruction = "\n\nIMPORTANT: The text content inside the generated figure MUST be in Chinese (Simplified Chinese). Please ensure all labels, titles, and descriptions in the figure description are in Chinese." - else: - lang_instruction = "\n\nIMPORTANT: The text content inside the generated figure MUST be in English. Please ensure all labels, titles, and descriptions in the figure description are in English." - - # 如果是图片编辑模式,state.paper_idea 可能被设为了 "Image Edit Mode",此时不应追加指令到 meaningless string - # 但如果是 TEXT 模式,则需要。如果是 PDF 模式,paper_idea 是提取出来的摘要。 - - return state.paper_idea + lang_instruction - - # ============================================================== - # NODES - # ============================================================== - async def paper_idea_extractor_node(state: Paper2FigureState) -> Paper2FigureState: - paper_idea_extractor = create_graph_agent("paper_idea_extractor") - state = await paper_idea_extractor.execute(state, use_agent=True) - return state - - async def figure_desc_generator_node(state: Paper2FigureState) -> Paper2FigureState: - figure_desc_generator = create_react_agent("figure_desc_generator", - max_retries=5, - model_name=getattr(state.request, "fig_desc_model", "gpt-5.1")) - state = await figure_desc_generator.execute(state, use_agent=True) - return state - - async def figure_generator_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 调用图像模型生成/编辑单张图(内容 + 布局模版),增加轻量 retry + 降低超时时间。 - - 参考 pdf2ppt_optimized 中 _call_image_api_with_retry 的做法,这里做一个本地版: - - 单次 HTTP 调用 timeout 控制在 60s(通过参数下传) - - 最多重试 3 次,每次失败打印详细日志 - """ - # 安全获取 
figure_desc_generator 的结果,可能因 input_type=FIGURE 跳过了该节点 - fd_gen_result = state.agent_results.get("figure_desc_generator") - prompt = "" - if fd_gen_result: - prompt = fd_gen_result.get("results", {}).get("fig_desc", "") - - safe_prompt = json.dumps(prompt, ensure_ascii=False) if prompt else "" - - edit_prompt = state.request.get("edit_prompt") - image_path = state.request.get("prev_image") - - final_prompt = edit_prompt if image_path else safe_prompt - - log.info( - f"[p2f_image_only] final_prompt(len={len(final_prompt)}), " - f"edit_prompt_len={len(edit_prompt or '')}, image_path={image_path}, " - f"safe_prompt_len={len(safe_prompt or '')}" - ) - - result_root = Path(_ensure_result_path(state)).resolve() - result_root.mkdir(parents=True, exist_ok=True) - - # 1) Generate Content Image - fig_name = f"fig_{int(time.time())}.png" - save_path = str((result_root / fig_name).resolve()) - - api_url = state.request.chat_api_url - api_key = state.request.chat_api_key or os.getenv("DF_API_KEY") - model = state.request.gen_fig_model - aspect_ratio = state.aspect_ratio - - async def _call_image_api_with_retry(coro_factory, retries: int = 3, delay: float = 2.0) -> bool: - """ - 本地轻量级 retry:仅负责 httpx / 网络级错误或服务超时。 - 失败时不会抛到外层,而是返回 False,由 workflow 决定如何处理。 - """ - last_err = None - for attempt in range(1, retries + 1): - try: - log.info(f"[p2f_image_only] image api attempt {attempt}/{retries} ...") - await coro_factory() - log.info("[p2f_image_only] image api succeed") - return True - except Exception as e: - last_err = e - log.error(f"[p2f_image_only] image api failed attempt {attempt}/{retries}: {e}") - if attempt < retries: - await asyncio.sleep(delay) - log.error(f"[p2f_image_only] image api failed after {retries} attempts: {last_err}") - return False - - async def _gen_image(): - await generate_or_edit_and_save_image_async( - prompt=final_prompt, - save_path=save_path, - aspect_ratio=aspect_ratio, - api_url=api_url, - api_key=api_key, - model=model, - image_path=image_path, 
- use_edit=True if image_path else False, - timeout=60, - resolution='2K' - ) - - ok = await _call_image_api_with_retry(_gen_image) - if not ok: - # 将失败信息写入 state,避免直接 500 - state.agent_results["gen_img_error"] = { - "msg": "image generation failed after retries", - "save_path": save_path, - } - log.error("[p2f_image_only] image generation failed, see previous logs for details") - return state - - state.agent_results["gen_img"] = {"path": save_path} - state.fig_draft_path = save_path - - return state - - # ============================================================== - # Registry - # ============================================================== - def set_entry_node(state: Paper2FigureState) -> str: - if state.request.input_type == "PDF": - return "paper_idea_extractor" - elif state.request.input_type == "TEXT": - return "figure_desc_generator" - elif state.request.input_type == "FIGURE": - return "figure_generator" - else: - log.error(f"Invalid input type: {state.request.input_type}. Only PDF, TEXT and FIGURE are supported.") - return "_end_" - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - nodes = { - '_start_': _init_result_path, - "paper_idea_extractor": paper_idea_extractor_node, - "figure_desc_generator": figure_desc_generator_node, - "figure_generator": figure_generator_node, - '_end_': lambda state: state, - } - - edges = [ - ("paper_idea_extractor", "figure_desc_generator"), - ("figure_desc_generator", "figure_generator"), - ("figure_generator", "_end_"), # End after image generation - ] - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", set_entry_node) - return builder diff --git a/dataflow_agent/workflow/wf_paper2figure_with_sam.py b/dataflow_agent/workflow/wf_paper2figure_with_sam.py deleted file mode 100644 index 0041fa1..0000000 --- a/dataflow_agent/workflow/wf_paper2figure_with_sam.py +++ /dev/null @@ -1,935 +0,0 @@ -""" -icongen workflow 
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-10-27 11:11:56 - -1. 在 **TOOLS** 区域定义需要暴露给 Prompt 的前置工具 -2. 在 **NODES** 区域实现异步节点函数 (await-able) -3. 在 **EDGES** 区域声明有向边 -4. 最后返回 builder.compile() 或 GenericGraphBuilder -""" - -from __future__ import annotations -import asyncio -import json -import os -from dataflow_agent.state import MainState, Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder - - -from dataflow_agent.workflow.registry import register -# from dataflow_agent.agentroles import get_agent_cls, create_agent - -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from langchain.tools import tool -from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode, tools_condition - -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_bg_remove, local_tool_for_raster_to_svg, free_bg_rm_model -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool.mineru_tool import ( - svg_to_emf, - recursive_mineru_layout, -) -from dataflow_agent.agentroles import create_graph_agent,create_react_agent - -import re, pdfplumber, PyPDF2, time, shutil, fitz -import numpy as np -from PIL import Image - -from dataflow_agent.utils import ( - build_output_directory, - add_image_element, - add_text_element, - setup_presentation_size, - get_project_root, - pixels_to_inches, -) - -from pathlib import Path -import time, random -from pptx import Presentation -from pptx.dml.color import RGBColor -from pptx.util import Inches - - -log = get_logger(__name__) - -TEMPLATE_EDIT_PROMPT = ( -"Transform the original image into a pure 
layout made ONLY of solid colored blocks:\n" -"1. Keep only the outermost rectangles and arrows (if they exist).\n" -"2. Delete everything inside them: all titles, subtitles, texts, icons, illustrations, and any inner shapes.\n" -"3. Turn each remaining outer shape into a solid color block; remove borders if possible.\n" -"4. Keep the layout exactly the same: same positions, sizes, alignment, and spacing.\n" -"5. Do NOT add any text, labels, or symbols anywhere.\n" -"Finally, output a description of this empty color-block template (no text content at all)." -) - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 统一本次 paper2figure_with_sam workflow 的根输出目录: - - 如果 state.result_path 已存在(通常由调用方传入),直接使用; - - 否则:使用 get_project_root() / "outputs" / "paper2figure" / , - 并写回 state.result_path,后续节点共享同一目录。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "paper2figure" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - -def _ts_name(stem: str, ext: str = ".png") -> str: - timestamp = int(time.time()) # 获取当前时间戳(秒) - return f"./{stem}{timestamp}{ext}" - -@register("paper2fig_with_sam") -def create_p2fig_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf paper2fig - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, - entry_point="_start_") # 自行修改入口 - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - @builder.pre_tool("paper_content", "paper_idea_extractor") - def _get_abstract_intro(state: Paper2FigureState): - """ - Robustly extract Abstract + Introduction from PDF. - """ - - # 1. 
Read metadata title - try: - with open(state.paper_file, 'rb') as f: - reader = PyPDF2.PdfReader(f) - paper_title = reader.metadata.get('/Title', 'Unknown Title') - except Exception: - paper_title = "Unknown Title" - - # Open the PDF file using the path from state - file_path = state.paper_file - pdf_document = fitz.open(file_path) - - # Extract text from the first 10 pages - text = "" - for page_num in range(min(10, len(pdf_document))): - page = pdf_document.load_page(page_num) - text += page.get_text("text") - - content = text.strip() - - final_text = ( - f"The title of the paper is {paper_title}\n\n" - f"Here's first ten page content: {content}" - ) - - log.info(f"{final_text}") - return final_text - - @builder.pre_tool("paper_idea", "figure_desc_generator") - def _get_paper_idea(state: Paper2FigureState): - """ - Return paper ideas summary. - """ - return state.paper_idea - - # ============================================================== - # NODES - # ============================================================== - async def paper_idea_extractor_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 提取论文的关键贡献点 - """ - paper_idea_extractor = create_graph_agent("paper_idea_extractor") - state = await paper_idea_extractor.execute(state, use_agent=True) - return state - - async def figure_desc_generator_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 图标提示词生成器节点 - """ - figure_desc_generator = create_react_agent("figure_desc_generator", - max_retries=5, - model_name="gpt-5.1") - state = await figure_desc_generator.execute(state, use_agent=True) - return state - - async def figure_generator_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 图像生成或编辑节点: - 1) 先生成带内容的图 (fig_draft_path) - 2) 再基于该图进行一次固定提示词的二次编辑,生成空框模板图 (fig_layout_path) - """ - prompt = state.agent_results.get("figure_desc_generator").get("results").get("fig_desc", {}) - safe_prompt = json.dumps(prompt, ensure_ascii=False) # 确保中文字符正常显示 - - edit_prompt = 
state.request.get("edit_prompt") - image_path = state.request.get("prev_image") - - # 如果是二次编辑,prompt可以为空 - final_prompt = edit_prompt if image_path else safe_prompt - - log.info(f'final_prompt{final_prompt} - edit_prompt:{edit_prompt} - image_path:{image_path} - prompt:{safe_prompt}') - - # 统一输出根目录(outputs/paper2figure/) - result_root = Path(_ensure_result_path(state)) - result_root.mkdir(parents=True, exist_ok=True) - - # 1) 生成带内容的图,直接存到 result_root - fig_name = f"fig_{int(time.time())}.png" - save_path = str((result_root / fig_name).resolve()) - - await generate_or_edit_and_save_image_async( - prompt=final_prompt, - save_path=save_path, - aspect_ratio=state.aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - image_path=image_path, - use_edit=True if image_path else False - ) - state.agent_results["gen_img"] = {"path": save_path} - state.fig_draft_path = save_path - - # 2) 基于第一次生成的图,做一次“空模板”二次编辑,也放在 result_root - # TEMPLATE_EDIT_PROMPT = ( - # "Keep only the outermost rectangles and arrows(if any in the original box).\n" - # "Remove all inner content including title, subtitles, icons, explainary texts and all that.\n" - # "Keep the layout exactly the same.\n" - # "Output a description of an empty template composed of these boxes." 
- # ) - - layout_name = f"layout_{int(time.time())}.png" - layout_save_path = str((result_root / layout_name).resolve()) - await generate_or_edit_and_save_image_async( - prompt=TEMPLATE_EDIT_PROMPT, - save_path=layout_save_path, - aspect_ratio=state.aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - image_path=save_path, - use_edit=True, - ) - state.fig_layout_path = layout_save_path - state.agent_results["gen_img_template"] = {"path": layout_save_path} - - return state - - async def figure_layout_sam_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 针对二次编辑后的空框模板图 (fig_layout_path) 进行: - SAM 自动分割 -> 过滤 -> 裁剪子图 -> PNG->SVG->EMF, - 结果写入 state.layout_items,仅作为 PPT 背景框架层。 - - 注意: - - segment_layout_boxes 返回的 bbox 是基于 layout 图像尺寸的归一化坐标 [0,1]; - - 这里显式转换一份像素坐标 bbox_px,后面插入 PPT 时统一按像素 → 英寸 → Emu 的规则处理, - 和 add_image_element / add_text_element 的坐标系保持一致,避免 EMF 位置/尺寸错乱导致“看不到”的问题。 - """ - try: - if not state.fig_layout_path and state.request.input_type == "FIGURE": - result_root = Path(_ensure_result_path(state)) - result_root.mkdir(parents=True, exist_ok=True) - log.critical(f"[figure_layout_sam] fig_layout_path 为空, 需要更新Layout图") - layout_name = f"layout_{int(time.time())}.png" - layout_save_path = str((result_root / layout_name).resolve()) - await generate_or_edit_and_save_image_async( - prompt="1.Remove all text content; keep only the outermost rectangular frames and arrows (if any).\n" - "2.Keep the layout unchanged.\n" - "3.Change the background color to white.", - save_path=layout_save_path, - aspect_ratio=state.aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - image_path=f"{get_project_root()}/{state.fig_draft_path}", - use_edit=True, - ) - state.fig_layout_path = layout_save_path - state.agent_results["gen_img_template"] = {"path": 
layout_save_path} - - img_path = Path(state.fig_layout_path) - if not img_path.exists(): - log.error(f"[figure_layout_sam] fig_layout_path 不存在: {img_path}") - return state - - base_dir = Path(_ensure_result_path(state)) - out_dir = base_dir / "layout_items" - out_dir.mkdir(parents=True, exist_ok=True) - - sam_ckpt = f'{get_project_root()}/sam_b.pt' - # SAM LB Port 8020 - sam_server_urls = ["http://localhost:8020"] - - # 1. SAM 分割 + 过滤 + 裁剪子图 (优先使用远程服务) - try: - layout_items = segment_layout_boxes_server( - image_path = str(img_path), - output_dir = str(out_dir), - server_urls = sam_server_urls, - checkpoint = sam_ckpt, - min_area = 200, - min_score = 0.0, - iou_threshold = 0.2, - top_k = 15, - nms_by = "mask", - ) - except Exception as e: - log.error(f"[figure_layout_sam] Remote SAM failed: {e}. Fallback to local.") - # Fallback to local if server fails - layout_items = segment_layout_boxes( - image_path = str(img_path), - output_dir = str(out_dir), - checkpoint = sam_ckpt, - # 这里的参数可以根据 mask_detail_level 调整 - min_area = 200, - min_score = 0.0, - iou_threshold = 0.2, - top_k = 15, - nms_by = "mask", - ) - # 只有本地运行时才需要手动释放模型 - free_sam_model(checkpoint= sam_ckpt) - - log.info(f"[figure_layout_sam] SAM 分割结果: {len(layout_items)} 个布局元素") - - # layout 图实际像素尺寸,用于把归一化 bbox 转为像素 bbox - try: - layout_img = Image.open(str(img_path)) - layout_w, layout_h = layout_img.size - except Exception as e: - log.error(f"[figure_layout_sam] 打开 layout 图失败: {e}") - layout_w, layout_h = 1024, 1024 # 兜底,和默认 slide 尺寸一致 - - # 2. 
每个 layout PNG 转 SVG -> EMF,并补充像素坐标 bbox_px - for idx, it in enumerate(layout_items): - png_path = it.get("png_path") - if not png_path: - continue - - # 将归一化 bbox 映射到像素坐标,和 fig_mask 的像素 bbox 保持一致 - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * layout_w)) - y1 = int(round(y1n * layout_h)) - x2 = int(round(x2n * layout_w)) - y2 = int(round(y2n * layout_h)) - if x2 > x1 and y2 > y1: - it["bbox_px"] = [x1, y1, x2, y2] - else: - log.warning(f"[figure_layout_sam] 无效 bbox: {bbox} -> 像素 [{x1},{y1},{x2},{y2}]") - - svg_path = out_dir / f"layout_{idx}.svg" - svg_abs = local_tool_for_raster_to_svg( - { - "image_path": png_path, - "output_svg": str(svg_path), - "colormode": "color", - "hierarchical": "stacked", - "mode": "spline", - } - ) - it["svg_path"] = svg_abs - - emf_path = out_dir / f"layout_{idx}.emf" - try: - emf_abs = svg_to_emf(svg_abs, str(emf_path)) - it["emf_path"] = emf_abs - except Exception as e: - log.error(f"[figure_layout_sam] svg_to_emf failed for {svg_abs}: {e}") - it["emf_path"] = None - - state.layout_items = layout_items - log.info(f"[figure_layout_sam] 共生成 {len(layout_items)} 个布局元素") - # log.info(f'state.layout_items : {state.layout_items}') - - except Exception as e: - log.error(f"[figure_layout_sam] Critical Failure, fallback to empty layout_items: {e}") - state.layout_items = [] - - return state - - async def figure_mask_generator_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 生成Figure进行元素切割,并提取 bbox + image_path 信息,递归处理子图。 - 使用 MinerU HTTP 对原始带内容的图 (fig_draft_path) 做解析,得到内容层元素。 - 规则: - - 标题块(type == 'title') 保留为 text; - - 其它所有块一律从顶层图裁剪出子图,当作 image,用于 icon / 局部视觉元素。 - """ - try: - img_path = Path(state.fig_draft_path) - if not img_path.exists(): - log.error(f"[figure_mask] fig_draft_path 不存在: {img_path}") - return state - - # MinerU 所有中间结果统一放在本次 outputs 下 - base_dir = Path(_ensure_result_path(state)) - out_dir = base_dir / "mineru_recursive" - out_dir.mkdir(parents=True, 
exist_ok=True) - log.info(f"[figure_mask] MinerU 输出目录: {out_dir}") - - # MinerU 端口:优先从 state.request.mineru_port 读取,默认 8010 - port = getattr(state.request, "mineru_port", 8010) - max_depth = getattr(state, "mask_detail_level", 3) - - log.critical(f"mask detail level : {max_depth} ") - log.critical(f'[img_path]: {img_path}') - log.critical(f'[mineru_port]: {port}') - - # 1. 调用新的 HTTP MinerU 递归处理,获取元素列表(归一化坐标) - mineru_items = await recursive_mineru_layout( - image_path=str(img_path), - port=port, - max_depth=max_depth, - output_dir=out_dir, - ) - log.info(f"mineru_items : {mineru_items}") - - # 顶层图像尺寸,用于 norm->pixel 映射与裁剪 - top_img = Image.open(state.fig_draft_path) - top_w, top_h = top_img.size - - # 图标原图输出目录 - icons_raw_dir = base_dir / "icons_raw" - icons_raw_dir.mkdir(parents=True, exist_ok=True) - - fig_mask = [] - icon_count = 0 - text_count = 0 - - details = 1 - if state.request.figure_complex == "easy": - details = 1 - elif state.request.figure_complex == "hard": - details = 10 - else: - details = 5 - - # 如果 MinerU 只返回了 小于等于 6 个整体元素:按 SAM 布局切子图,再对每个子图单独跑 MinerU, - # 以便获取该布局块内部的文字和更细粒度元素。底图始终使用 fig_draft_path。 - if len(mineru_items) <= details: - from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_aio_two_step_extract - - layout_items = getattr(state, "layout_items", None) or [] - log.info(f"[figure_mask] mineru_items size = {len(mineru_items)}, 使用 SAM 布局({len(layout_items)} 个)进行二次 MinerU 拆分") - - # 子图保存目录 - sub_root_dir = base_dir / "mineru_sub_images" - sub_root_dir.mkdir(parents=True, exist_ok=True) - - for layout_idx, layout_it in enumerate(layout_items): - # 优先使用在 figure_layout_sam_node 中写入的像素 bbox_px;否则退回 bbox 视为像素坐标 - bbox_px = layout_it.get("bbox_px") or layout_it.get("bbox") - if not bbox_px or len(bbox_px) != 4: - continue - lx1, ly1, lx2, ly2 = bbox_px - # 粗筛 bbox - if lx2 <= lx1 or ly2 <= ly1: - continue - - # 边界裁剪到顶层图尺寸 - lx1 = max(0, min(top_w, int(round(lx1)))) - ly1 = max(0, min(top_h, int(round(ly1)))) - lx2 = max(0, min(top_w, 
int(round(lx2)))) - ly2 = max(0, min(top_h, int(round(ly2)))) - if lx2 <= lx1 or ly2 <= ly1: - continue - - # 1) 从原始 fig_draft_path 裁出当前布局块子图 - try: - sub_img = top_img.crop((lx1, ly1, lx2, ly2)) - except Exception as e: - log.error(f"[figure_mask] 裁剪 SAM 子图失败 layout_idx={layout_idx}, bbox=({lx1},{ly1},{lx2},{ly2}): {e}") - continue - - sub_dir = sub_root_dir / f"layout_{layout_idx}" - sub_dir.mkdir(parents=True, exist_ok=True) - sub_path = sub_dir / f"sam_sub_{layout_idx}.png" - try: - sub_img.save(sub_path) - except Exception as e: - log.error(f"[figure_mask] 保存 SAM 子图失败 layout_idx={layout_idx}, path={sub_path}: {e}") - continue - - # 2) 对子图再次调用 MinerU(只做一层 two_step_extract,不再递归) - try: - sub_blocks = await run_aio_two_step_extract(str(sub_path), port=port) - except Exception as e: - log.error(f"[figure_mask] 子图 MinerU 解析失败 layout_idx={layout_idx}, path={sub_path}: {e}") - continue - - sub_w, sub_h = sub_img.size - - # 3) 遍历子图内的 MinerU block,映射到整图像素坐标系 - for blk_idx, blk in enumerate(sub_blocks): - blk_type_raw = blk.get("type") or "" - blk_type = blk_type_raw.lower() - bbox_norm = blk.get("bbox") - text = blk.get("text") or blk.get("content") or "" - if not bbox_norm or len(bbox_norm) != 4: - continue - - sx1n, sy1n, sx2n, sy2n = bbox_norm - # 规整到 [0,1],避免越界 - sx1n = max(0.0, min(1.0, float(sx1n))) - sy1n = max(0.0, min(1.0, float(sy1n))) - sx2n = max(0.0, min(1.0, float(sx2n))) - sy2n = max(0.0, min(1.0, float(sy2n))) - if sx2n <= sx1n or sy2n <= sy1n: - continue - - # 子图归一化 -> 子图像素 - sx1 = int(round(sx1n * sub_w)) - sy1 = int(round(sy1n * sub_h)) - sx2 = int(round(sx2n * sub_w)) - sy2 = int(round(sy2n * sub_h)) - if sx2 <= sx1 or sy2 <= sy1: - continue - - # 子图像素 -> 整图像素(加上布局块的偏移) - gx1 = lx1 + sx1 - gy1 = ly1 + sy1 - gx2 = lx1 + sx2 - gy2 = ly1 + sy2 - - # 再次 clamp 到整图范围 - gx1 = max(0, min(top_w, gx1)) - gy1 = max(0, min(top_h, gy1)) - gx2 = max(0, min(top_w, gx2)) - gy2 = max(0, min(top_h, gy2)) - if gx2 <= gx1 or gy2 <= gy1: - continue - - px_bbox = [gx1, 
gy1, gx2, gy2] - - # 文本块:直接作为 text 元素 - if blk_type in ["title", "text"]: - fig_mask.append( - { - "type": "text", - "bbox": px_bbox, - "text": text, - "text_level": 1 if blk_type == "title" else None, - "page_idx": 0, - } - ) - text_count += 1 - else: - # 非文本块:从顶层图再次裁剪成小图,作为 image 元素 - try: - crop = top_img.crop((gx1, gy1, gx2, gy2)) - icon_path = icons_raw_dir / f"blk_sub_{layout_idx}_{blk_idx}.png" - crop.save(icon_path) - fig_mask.append( - { - "type": "image", - "bbox": px_bbox, - "img_path": str(icon_path), - "page_idx": 0, - } - ) - icon_count += 1 - except Exception as e: - log.error( - f"[figure_mask] 子块裁剪失败 layout_idx={layout_idx}, blk_idx={blk_idx}, bbox={px_bbox}: {e}" - ) - # 兜底:退化为文本元素,保持兼容 - fig_mask.append( - { - "type": "text", - "bbox": px_bbox, - "text": text, - "text_level": None, - "page_idx": 0, - } - ) - text_count += 1 - else: - # 正常路径:MinerU 输出多个元素,仍按原逻辑基于 fig_draft_path 裁剪 - for idx, it in enumerate(mineru_items): - elem_type_raw = it.get("type") or "" - elem_type = elem_type_raw.lower() - bbox = it.get("bbox") - text = (it.get("text") or it.get("content") or "").strip() - - if not bbox or len(bbox) != 4: - continue - - # 归一化 -> 像素坐标(基于原始 fig_ 图尺寸) - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * top_w)) - y1 = int(round(y1n * top_h)) - x2 = int(round(x2n * top_w)) - y2 = int(round(y2n * top_h)) - - if x2 <= x1 or y2 <= y1: - continue - - px_bbox = [x1, y1, x2, y2] - - # 1) 只要有文字内容,一律作为文本元素 - if text: - fig_mask.append( - { - "type": "text", - "bbox": px_bbox, - "text": text, - "text_level": 1 if elem_type == "title" else None, - "page_idx": 0, - } - ) - text_count += 1 - continue - - # 2) 没有任何文字内容的块:一律裁图,当作 image,用于 icon / 元素图层 - try: - crop = top_img.crop((x1, y1, x2, y2)) - icon_path = icons_raw_dir / f"blk_{idx}.png" - crop.save(icon_path) - icon_abs = str(icon_path) - fig_mask.append( - { - "type": "image", - "bbox": px_bbox, - "img_path": icon_abs, - "page_idx": 0, - } - ) - icon_count += 1 - except Exception as e: - 
log.error(f"[figure_mask] 裁剪子图失败 idx={idx}, bbox={px_bbox}: {e}") - # 兜底:作为普通文本 - fig_mask.append( - { - "type": "text", - "bbox": px_bbox, - "text": text, - "text_level": None, - "page_idx": 0, - } - ) - text_count += 1 - - type_counter = {} - for e in fig_mask: - t = e.get("type") - type_counter[t] = type_counter.get(t, 0) + 1 - - log.info( - f"[figure_mask] fig_mask size = {len(fig_mask)}, " - f"type distribution = {type_counter}, " - f"title_text={text_count}, icons(raw)={icon_count}" - ) - - # 更新 state 的 fig_mask 信息 - state.fig_mask = fig_mask - log.info(f"[figure_mask] 共解析出 {len(fig_mask)} 个元素 (via MinerU HTTP + SAM fallback, pixel bbox + raw icons)") - - except Exception as e: - log.error(f"[figure_mask] Critical Failure, fallback to empty fig_mask: {e}") - state.fig_mask = [] - - return state - - async def figure_icon_bg_remover_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 把Mask里面的图标去除背景 - """ - try: - base_dir = Path(_ensure_result_path(state)) - icons_dir = base_dir / "icons" - icons_dir.mkdir(parents=True, exist_ok=True) - - img_cnt = 0 - for item in state.fig_mask: - if item.get('type') in ['image', 'table']: - img_cnt += 1 - try: - output_path = local_tool_for_bg_remove({ - "image_path": item.get('img_path'), - "model_path": state.request.bg_rm_model, - "output_dir": str(icons_dir) - }) - if output_path: - item['img_path'] = output_path - log.info(f"[figure_icon_bg_remover] background removed: {output_path}") - else: - log.warning(f"[figure_icon_bg_remover] bg remove failed for {item.get('img_path')}") - except Exception as e: - log.warning(f"[figure_icon_bg_remover] Single item bg remove failed: {e}") - - log.info(f"[figure_icon_bg_remover] processed image/table elements: {img_cnt}") - - # 抠图完成后,显式释放 RGB2.0 模型占用的显存 - try: - free_bg_rm_model(model_path=state.request.bg_rm_model) - log.info("[figure_icon_bg_remover] freed RMBG-2.0 model from GPU") - except Exception as e: - log.error(f"[figure_icon_bg_remover] free_bg_rm_model failed: 
{e}") - - except Exception as e: - log.error(f"[figure_icon_bg_remover] Critical Failure, skipping bg removal: {e}") - - return state - - async def figure_ppt_generation_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 生成单页 PPT: - - 第 1 页:原始组合页(layout EMF + MinerU 文本 + 图像) - - 第 2 页:仅渲染所有 layout_items 的 EMF,用于检查 SAM 框架 - - 第 3 页:直接铺满整页的 full 内容 PNG(state.fig_draft_path) - - 关键点: - - layout_items 在 figure_layout_sam_node 中已经给出了像素坐标 bbox_px; - - 这里完全沿用 add_text_element / add_image_element 使用的“像素 → 英寸 → Emu”规则, - 确保 EMF 背景框和内容层在同一几何坐标系下,避免因为单位不一致导致 EMF 肉眼不可见。 - - full 内容 PNG 页直接使用 fig_draft_path,按原图尺寸换算为英寸铺满整页。 - """ - try: - # 从state获取输出目录(若未设置则自动初始化 outputs/paper2figure/) - output_dir = Path(_ensure_result_path(state)) - output_dir.mkdir(parents=True, exist_ok=True) - - # 生成唯一文件名 - timestamp = int(time.time()) - ppt_filename = f"presentation_{timestamp}.pptx" - ppt_path = output_dir / ppt_filename - - # 创建Presentation对象 - prs = Presentation() - - # 设置PPT尺寸,依据原始带内容图 - img = Image.open(state.fig_draft_path) - width_px, height_px = img.size - - # --- 检查尺寸并计算缩放 (PPT限制56英寸) --- - max_ppt_inches = 56.0 - dpi = 96.0 - max_pixels = int(max_ppt_inches * dpi) - - scale_ratio = 1.0 - if width_px > max_pixels or height_px > max_pixels: - scale_ratio = max_pixels / max(width_px, height_px) - # 留点余量,避免临界值误差 - scale_ratio *= 0.99 - log.warning(f"[figure_ppt_generation] Image size ({width_px}x{height_px}) exceeds PPT limit. 
Scaling by {scale_ratio:.4f}") - - width_px = int(width_px * scale_ratio) - height_px = int(height_px * scale_ratio) - - slide_width_px, slide_height_px = setup_presentation_size(prs, width_px, height_px) - - # 空白布局 - blank_slide_layout = prs.slide_layouts[6] - - def _add_layout_emf(slide, item) -> bool: - """ - 将 layout_item 中的 EMF 按像素 bbox 放到 slide 上,坐标逻辑与 add_image_element 保持一致。 - 返回是否成功绘制。 - """ - emf_path = item.get("emf_path") - if not emf_path or not os.path.exists(emf_path): - if emf_path: - log.warning(f"[figure_ppt_generation] emf_path 不存在: {emf_path}") - return False - - # 优先使用像素 bbox_px,其次退回原始 bbox(假定已是像素坐标) - bbox = item.get("bbox_px") or item.get("bbox") - if not bbox or len(bbox) != 4: - log.warning(f"[figure_ppt_generation] layout_item 缺少有效 bbox: {item}") - return False - - x1, y1, x2, y2 = bbox - - # 应用缩放 - if scale_ratio != 1.0: - x1 = int(x1 * scale_ratio) - y1 = int(y1 * scale_ratio) - x2 = int(x2 * scale_ratio) - y2 = int(y2 * scale_ratio) - - if x2 <= x1 or y2 <= y1: - log.warning(f"[figure_ppt_generation] 非法 bbox 像素坐标: {bbox} (scaled)") - return False - - # 像素 → 英寸,完全沿用 utils.pixels_to_inches 的规则 - left_in = pixels_to_inches(x1) - top_in = pixels_to_inches(y1) - width_in = pixels_to_inches(x2 - x1) - height_in = pixels_to_inches(y2 - y1) - - log.info(f"[figure_ppt_generation] 添加 EMF:") - log.info(f" bbox 像素: [{x1}, {y1}, {x2}, {y2}]") - log.info(f" 英寸坐标: left={left_in:.2f}, top={top_in:.2f}, width={width_in:.2f}, height={height_in:.2f}") - log.info(f" emf_path: {emf_path}") - - try: - slide.shapes.add_picture( - emf_path, - Inches(left_in), - Inches(top_in), - Inches(width_in), - Inches(height_in), - ) - return True - except Exception as e: - log.error(f"[figure_ppt_generation] add_picture EMF 失败: {emf_path}, {e}") - return False - - # ========================= - # 第 1 页:完整组合页 - # ========================= - slide_main = prs.slides.add_slide(blank_slide_layout) - - # 白色背景 - background = slide_main.background - fill = background.fill - 
fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - - # 1) 先渲染 layout_items (SAM + SVG + EMF 背景层) - layout_drawn = 0 - for item in state.layout_items or []: - if _add_layout_emf(slide_main, item): - layout_drawn += 1 - - # 2) 再渲染 MinerU fig_mask(内容层) - img_drawn = 0 - text_drawn = 0 - for element in state.fig_mask or []: - # 应用缩放 (使用副本以免修改原数据) - element_copy = element.copy() - if scale_ratio != 1.0: - old_bbox = element_copy.get('bbox', [0,0,0,0]) - if len(old_bbox) == 4: - element_copy['bbox'] = [ - int(val * scale_ratio) for val in old_bbox - ] - - elem_type = element_copy.get('type', '') - - if elem_type == 'text': - add_text_element(slide_main, element_copy) - text_drawn += 1 - elif elem_type in ['image', 'table']: - add_image_element(slide_main, element_copy) - img_drawn += 1 - - # ========================= - # 第 2 页:仅 EMF 调试页 - # ========================= - slide_emf = prs.slides.add_slide(blank_slide_layout) - bg2 = slide_emf.background - fill2 = bg2.fill - fill2.solid() - fill2.fore_color.rgb = RGBColor(255, 255, 255) - - layout_debug_drawn = 0 - for item in state.layout_items or []: - if _add_layout_emf(slide_emf, item): - layout_debug_drawn += 1 - - # ========================= - # 第 3 页:full 内容 PNG 原图页(铺满整页) - # ========================= - slide_full = prs.slides.add_slide(blank_slide_layout) - bg3 = slide_full.background - fill3 = bg3.fill - fill3.solid() - fill3.fore_color.rgb = RGBColor(255, 255, 255) - - try: - # 整幅图从 (0,0) 开始铺满整页,坐标/尺寸仍按像素->英寸转换 - left_in = pixels_to_inches(0) - top_in = pixels_to_inches(0) - width_in = pixels_to_inches(width_px) - height_in = pixels_to_inches(height_px) - - slide_full.shapes.add_picture( - state.fig_draft_path, - Inches(left_in), - Inches(top_in), - Inches(width_in), - Inches(height_in), - ) - except Exception as e: - log.error(f"[figure_ppt_generation] add full PNG on page 3 failed: {e}") - - # 保存PPT - prs.save(str(ppt_path)) - state.ppt_path = ppt_path - print(f"PPT generated successfully: {ppt_path}") 
- print(f"Slide size: {slide_width_px}x{slide_height_px} pixels") - print(f"[MAIN] Total layout items: {len(state.layout_items)}, drawn: {layout_drawn}") - print(f"[MAIN] Total content elements added: {len(state.fig_mask)}, text_drawn={text_drawn}, img_drawn={img_drawn}") - print(f"[EMF_ONLY] layout items drawn: {layout_debug_drawn}") - - except Exception as e: - print(f"Error generating PPT: {e}") - - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - def set_entry_node(state: Paper2FigureState) -> str: - if(state.request.input_type == "PDF"): - log.critical(f'进入PDF node ......') - return "paper_idea_extractor" - elif(state.request.input_type == "TEXT"): - log.critical(f'进入TEXT node ......') - return "figure_desc_generator" - else: - log.error(f"Invalid input type: {state.request.input_type}. Only PDF and TEXT are supported.") - return "_end_" - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - """ - _start_ 节点:确保本次 workflow 有一个统一的 result_path 根目录。 - - 若用户已在 state.result_path 传入自定义目录,则直接使用该目录; - - 若未传入,则初始化为 get_project_root()/outputs/paper2figure/。 - """ - _ensure_result_path(state) - return state - - nodes = { - '_start_': _init_result_path, - "paper_idea_extractor": paper_idea_extractor_node, - "figure_desc_generator": figure_desc_generator_node, - "figure_generator": figure_generator_node, - "figure_layout_sam": figure_layout_sam_node, - "figure_mask_generator": figure_mask_generator_node, - "figure_icon_bg_remover": figure_icon_bg_remover_node, - "figure_ppt_generator": figure_ppt_generation_node, - '_end_': lambda state: state, # 终止节点 - } - - # ------------------------------------------------------------------ - # EDGES (从节点 A 指向节点 B) - # ------------------------------------------------------------------ - edges = [ - ("paper_idea_extractor", "figure_desc_generator"), - ("figure_desc_generator", "figure_generator"), 
- ("figure_generator", "figure_layout_sam"), - ("figure_layout_sam", "figure_mask_generator"), - ("figure_mask_generator", "figure_icon_bg_remover"), - ("figure_icon_bg_remover", "figure_ppt_generator"), - ("figure_ppt_generator", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", set_entry_node) - return builder diff --git a/dataflow_agent/workflow/wf_paper2page_content.py b/dataflow_agent/workflow/wf_paper2page_content.py deleted file mode 100644 index 99c056b..0000000 --- a/dataflow_agent/workflow/wf_paper2page_content.py +++ /dev/null @@ -1,483 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import time -import json -from pathlib import Path -from typing import List, Dict, Any -import re - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_react_agent, create_simple_agent -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root - -from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_mineru_pdf_extract, _shrink_markdown -from dataflow_agent.toolkits.multimodaltool.req_understanding import call_image_understanding_async - -log = get_logger(__name__) - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 参考 wf_paper2figure_with_sam.py 的做法: - 统一本次 paper2page_content workflow 的根输出目录: - - 如果 state.result_path 已存在(通常由调用方传入),直接使用; - - 否则:使用 get_project_root() / "outputs" / "paper2page_content" / , - 并写回 state.result_path,后续节点共享同一目录。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "paper2page_content" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _abs_path(p: 
str) -> str: - if not p: - return "" - try: - return str(Path(p).expanduser().resolve()) - except Exception: - return p - - -def _find_mineru_auto_dir(paper_dir: Path) -> Path | None: - """ - 探测 MinerU 实际输出的子目录(auto / hybrid_auto 等)。 - """ - candidates = ["auto", "hybrid_auto"] - for name in candidates: - d = paper_dir / name - if d.exists() and d.is_dir(): - return d - for child in sorted(paper_dir.iterdir()): - if child.is_dir() and list(child.glob("*.md")): - return child - return None - - -@register("paper2page_content") -def create_paper2page_content_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf paper2page_content - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - @builder.pre_tool("minueru_output", "outline_agent") - def _get_mineru_markdown(state: Paper2FigureState): - return state.minueru_output or "" - - @builder.pre_tool("text_content", "outline_agent") - def _get_text_content(state: Paper2FigureState): - return state.text_content or "" - - @builder.pre_tool("outline_feedback", "outline_refine_agent") - def _get_outline_feedback(state: Paper2FigureState): - return state.outline_feedback or "" - - @builder.pre_tool("minueru_output", "outline_refine_agent") - def _get_mineru_markdown_for_refine(state: Paper2FigureState): - return state.minueru_output or "" - - @builder.pre_tool("text_content", "outline_refine_agent") - def _get_text_content_for_refine(state: Paper2FigureState): - return state.text_content or "" - - @builder.pre_tool("pagecontent", "outline_refine_agent") - def _get_pagecontent_for_refine(state: Paper2FigureState): - return json.dumps(state.pagecontent or [], ensure_ascii=False) - - @builder.pre_tool("pagecontent_raw", "outline_refine_agent") - def 
_get_pagecontent_raw_for_refine(state: Paper2FigureState): - return state.pagecontent or [] - - # ---------- image pipeline pre_tools ---------- - @builder.pre_tool("image_items_json", "image_filter_agent") - def _get_image_items_json_for_filter(state: Paper2FigureState): - return json.dumps(getattr(state, "image_items", []) or [], ensure_ascii=False) - - @builder.pre_tool("query", "image_filter_agent") - def _get_query_for_filter(state: Paper2FigureState): - return getattr(state, "kb_query", "") or "" - - @builder.pre_tool("pagecontent_json", "kb_image_insert_agent") - def _get_pagecontent_json_for_insert(state: Paper2FigureState): - return json.dumps(state.pagecontent or [], ensure_ascii=False) - - @builder.pre_tool("image_items_json", "kb_image_insert_agent") - def _get_image_items_json_for_insert(state: Paper2FigureState): - return json.dumps(getattr(state, "filtered_image_items", []) or [], ensure_ascii=False) - - # ============================================================== - # NODES - # ============================================================== - def _start_(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - state.minueru_output = state.minueru_output or "" - state.text_content = state.text_content or "" - state.pagecontent = state.pagecontent or [] - state.outline_feedback = state.outline_feedback or "" - state.image_items = getattr(state, "image_items", []) or [] - state.filtered_image_items = getattr(state, "filtered_image_items", []) or [] - return state - - async def parse_pdf_pages(state: Paper2FigureState) -> Paper2FigureState: - """ - PDF: MinerU 解析 -> 读取 markdown 全文 -> 写入 state.minueru_output - - 目录约定(与 MinerU 实际行为对齐): - - 传入的输出根目录为 result_root = state.result_path - - MinerU 会在其下创建: - /auto/.md - /auto/images/*.jpg - - 我们将 state.mineru_root 指向实际承载 md 和 images 的 auto 目录, - 这样后续 asset_ref="images/xxx.jpg" 能解析到正确路径。 - """ - paper_pdf_path = Path(_abs_path(state.paper_file)) - if not paper_pdf_path.exists(): - 
log.error(f"[paper2page_content] PDF 文件不存在: {paper_pdf_path}") - state.minueru_output = "" - return state - - # 统一本次 workflow 的根输出目录 - result_root = Path(_ensure_result_path(state)) - result_root.mkdir(parents=True, exist_ok=True) - - pdf_stem = paper_pdf_path.stem - paper_dir = result_root / pdf_stem - - # 探测已有的 MinerU 输出目录(auto / hybrid_auto 等) - auto_dir = _find_mineru_auto_dir(paper_dir) if paper_dir.exists() else None - - if auto_dir is None: - try: - run_mineru_pdf_extract(str(paper_pdf_path), str(result_root), "modelscope") - except Exception as e: - log.error(f"[paper2page_content] run_mineru_pdf_extract 失败: {e}") - state.minueru_output = "" - return state - auto_dir = _find_mineru_auto_dir(paper_dir) - - if auto_dir is None: - log.error(f"[paper2page_content] MinerU 输出目录不存在: {paper_dir}") - state.minueru_output = "" - return state - - auto_dir = auto_dir.resolve() - markdown_path = auto_dir / f"{pdf_stem}.md" - if not markdown_path.exists(): - md_files = list(auto_dir.glob("*.md")) - markdown_path = md_files[0] if md_files else markdown_path - if not markdown_path.exists(): - log.error(f"[paper2page_content] Markdown 文件不存在: {markdown_path}") - state.minueru_output = "" - return state - - try: - md = markdown_path.read_text(encoding="utf-8") - except Exception as e: - log.error(f"[paper2page_content] 读取 markdown 失败: {markdown_path}, err={e}") - md = "" - # Avoid passing overly long markdown to downstream agents. 
- state.minueru_output = _shrink_markdown(md, max_h1=8, max_chars=30_000) - # 记录 MinerU 输出根目录 = 实际承载 md 与 images 的 auto 目录 - state.mineru_root = str(auto_dir) - log.info(f"[paper2page_content] minueru_output : {state.minueru_output[:100]} ") - return state - - async def prepare_text_input(state: Paper2FigureState) -> Paper2FigureState: - """ - TEXT: 直接进入 outline agent 前,把文本放到 state.text_content - """ - # 兼容:优先 paper2ppt 专用 text_content;如果外部通过 request.target 传入文本,也做兜底 - if not state.text_content: - state.text_content = getattr(state.request, "target", "") or "" - return state - - async def ppt_to_images(state: Paper2FigureState) -> Paper2FigureState: - """ - PPT/PPTX: 转成每页图片,写入 state.pagecontent: - [{"ppt_img_path": "/abs/slide_001.png"}, ...] - 注意:这里的 pagecontent 仅作为 outline agent 的输入材料,最终 pagecontent 会被 agent 改写。 - """ - ppt_path = Path(_abs_path(state.paper_file)) - if not ppt_path.exists(): - log.error(f"[paper2page_content] PPT 文件不存在: {ppt_path}") - state.pagecontent = [] - return state - - output_dir = Path(_ensure_result_path(state)) / "ppt_images" - output_dir.mkdir(parents=True, exist_ok=True) - - # 策略:优先 soffice 转 pdf,再 pdf2image - pdf_path = output_dir / f"{ppt_path.stem}.pdf" - if not pdf_path.exists(): - cmd = ( - f'soffice --headless --convert-to pdf --outdir "{output_dir}" "{ppt_path}"' - ) - # 这里不能用 execute_command 工具(在 workflow runtime 内执行),因此用 os.system 兜底; - - ret = os.system(cmd) - if ret != 0: - log.error( - f"[paper2page_content] soffice 转 pdf 失败(ret={ret}). 
" - f"请确认部署机器安装了 libreoffice/soffice。cmd={cmd}" - ) - state.pagecontent = [] - return state - - if not pdf_path.exists(): - log.error(f"[paper2page_content] soffice 转出的 pdf 不存在: {pdf_path}") - state.pagecontent = [] - return state - - try: - from pdf2image import convert_from_path - except Exception as e: - log.error(f"[paper2page_content] 缺少 pdf2image 依赖,无法将 pdf 转图片: {e}") - state.pagecontent = [] - return state - - try: - slide_imgs = convert_from_path(str(pdf_path)) - except Exception as e: - log.error(f"[paper2page_content] pdf2image 转换失败: {e}") - state.pagecontent = [] - return state - - page_items: List[Dict[str, Any]] = [] - for i, img in enumerate(slide_imgs): - img_path = output_dir / f"slide_{i:03d}.png" - try: - img.save(img_path, "PNG") - except Exception as e: - log.error(f"[paper2page_content] 保存 slide png 失败: {img_path}, err={e}") - continue - page_items.append({"ppt_img_path": str(img_path.resolve())}) - - state.pagecontent = page_items - return state - - async def outline_agent(state: Paper2FigureState) -> Paper2FigureState: - """ - Outline agent 骨架:你后续实现 agent 逻辑,产出 state.pagecontent(list[dict])。 - 这里仅负责创建并执行 agent,然后返回 state。 - """ - agent = create_react_agent( - name="outline_agent", - temperature=0.1, - max_retries=5, - parser_type="json", - ) - state = await agent.execute(state=state) - return state - - async def outline_refine_agent(state: Paper2FigureState) -> Paper2FigureState: - """ - outline_refine_agent: refine existing outline based on user feedback. 
- """ - agent = create_react_agent( - name="outline_refine_agent", - parser_type="json", - max_retries=5 - ) - state = await agent.execute(state=state) - return state - - async def deep_research_agent(state: Paper2FigureState) -> Paper2FigureState: - """ - Deep Research Agent: 接收 Topic,生成长文,更新 state.text_content - """ - log.info("[paper2page_content] Entering deep_research_agent...") - agent = create_simple_agent( - name="deep_research_agent", - temperature=0.7, - parser_type="text", # 直接输出长文本 - ) - state = await agent.execute(state=state) - return state - - # ---------- image pipeline nodes (mirroring kb_page_content) ---------- - - async def extract_md_images(state: Paper2FigureState) -> Paper2FigureState: - """从 MinerU 输出的 markdown 中提取图片引用路径。""" - mineru_root = getattr(state, "mineru_root", "") or "" - image_paths: List[str] = [] - if mineru_root: - try: - md_files = list(Path(mineru_root).glob("*.md")) - md_text = md_files[0].read_text(encoding="utf-8") if md_files else "" - except Exception as e: - log.error(f"[paper2page_content] 读取 md 失败: {e}") - md_text = "" - - if md_text: - md_imgs = re.findall(r"!\[[^\]]*\]\(([^)]+)\)", md_text) - html_imgs = re.findall(r"]+src=[\"']([^\"']+)[\"']", md_text) - for rel in md_imgs + html_imgs: - rel = rel.strip() - if not rel: - continue - img_path = Path(mineru_root) / rel - if img_path.exists(): - image_paths.append(str(img_path.resolve())) - - state.kb_md_images = list(dict.fromkeys(image_paths)) - log.info("[paper2page_content] extract_md_images: found %s images", len(state.kb_md_images)) - return state - - async def caption_images(state: Paper2FigureState) -> Paper2FigureState: - """合并 MinerU 提取图片与用户图片,并行补全 caption。""" - user_images = getattr(state, "kb_user_images", []) or [] - md_images = getattr(state, "kb_md_images", []) or [] - items: List[Dict[str, Any]] = [] - - for p in md_images: - items.append({"path": p, "caption": "", "source": "mineru"}) - for item in user_images: - path = item.get("path") or 
item.get("url") or "" - if not path: - continue - caption = item.get("description") or item.get("caption") or "" - items.append({"path": path, "caption": caption, "source": "user"}) - - # 去重 - unique = {} - for it in items: - unique[it["path"]] = it - items = list(unique.values()) - - async def _caption_one(it: Dict[str, Any]) -> Dict[str, Any]: - if it.get("caption"): - return it - try: - desc = await call_image_understanding_async( - model=getattr(state.request, "vlm_model", "gemini-2.5-flash"), - messages=[{"role": "user", "content": "Please provide a concise caption for this image for PPT slide selection."}], - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or state.request.api_key, - image_path=it.get("path"), - ) - it["caption"] = desc.strip() - except Exception as e: - log.error(f"[paper2page_content] caption failed: {e}") - return it - - tasks = [_caption_one(it) for it in items] - if tasks: - items = list(await asyncio.gather(*tasks)) - - state.image_items = items - log.info("[paper2page_content] caption_images: %s items", len(items)) - return state - - async def filter_images_agent(state: Paper2FigureState) -> Paper2FigureState: - """按 query 筛选相关图片;无 query 则全部保留。""" - query = (getattr(state, "kb_query", "") or "").strip() - if not state.image_items: - state.filtered_image_items = [] - return state - if not query: - state.filtered_image_items = list(state.image_items) - return state - - agent = create_react_agent( - name="image_filter_agent", - temperature=0.1, - max_retries=3, - parser_type="json", - ) - state = await agent.execute(state=state) - if not getattr(state, "filtered_image_items", None): - state.filtered_image_items = list(state.image_items) - return state - - async def insert_images_agent(state: Paper2FigureState) -> Paper2FigureState: - """将筛选后的图片作为独立页面插入 pagecontent。""" - if not getattr(state, "filtered_image_items", None): - return state - agent = create_react_agent( - name="kb_image_insert_agent", - 
temperature=0.2, - max_retries=3, - parser_type="json", - ) - state = await agent.execute(state=state) - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - def _route_input(state: Paper2FigureState) -> str: - feedback = (state.outline_feedback or "").strip() - if feedback and state.pagecontent: - log.critical("走 OUTLINE 反馈修订路径") - return "outline_refine_agent" - t = getattr(state.request, "input_type", None) or getattr(state, "input_type", None) or "" - t = str(t).upper().strip() - if t == "PDF": - log.critical("走 PDF 路径") - return "parse_pdf_pages" - if t == "TEXT": - log.critical("走 TEXT 路径") - return "prepare_text_input" - if t == "TOPIC": - log.critical("走 TOPIC 路径 (Deep Research)") - return "deep_research_agent" - if t in ["PPT", "PPTX"]: - log.critical("走 PPT 路径") - return "ppt_to_images" - log.error(f"[paper2page_content] Invalid input_type: {t}") - return "_end_" - - def _route_after_outline(state: Paper2FigureState) -> str: - """outline_agent 完成后,判断是否有图片需要处理。""" - mineru_root = getattr(state, "mineru_root", "") or "" - user_images = getattr(state, "kb_user_images", []) or [] - if mineru_root or user_images: - return "extract_md_images" - return "_end_" - - nodes = { - "_start_": _start_, - "parse_pdf_pages": parse_pdf_pages, - "prepare_text_input": prepare_text_input, - "ppt_to_images": ppt_to_images, - "deep_research_agent": deep_research_agent, - "outline_agent": outline_agent, - "outline_refine_agent": outline_refine_agent, - "extract_md_images": extract_md_images, - "caption_images": caption_images, - "filter_images_agent": filter_images_agent, - "insert_images_agent": insert_images_agent, - "_end_": lambda state: state, - } - - edges = [ - ("parse_pdf_pages", "outline_agent"), - ("prepare_text_input", "outline_agent"), - ("deep_research_agent", "outline_agent"), - ("ppt_to_images", "_end_"), - ("outline_refine_agent", "_end_"), - # 
image pipeline chain - ("extract_md_images", "caption_images"), - ("caption_images", "filter_images_agent"), - ("filter_images_agent", "insert_images_agent"), - ("insert_images_agent", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_conditional_edge("_start_", _route_input) - builder.add_conditional_edge("outline_agent", _route_after_outline) - return builder diff --git a/dataflow_agent/workflow/wf_paper2page_content_for_long_paper.py b/dataflow_agent/workflow/wf_paper2page_content_for_long_paper.py deleted file mode 100644 index d0a236a..0000000 --- a/dataflow_agent/workflow/wf_paper2page_content_for_long_paper.py +++ /dev/null @@ -1,651 +0,0 @@ -from __future__ import annotations - -import os -import time -import copy -from pathlib import Path -from typing import List, Dict, Any, Tuple -import re - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_react_agent, create_simple_agent -from dataflow_agent.agentroles.paper2any_agents.long_paper_outline_agent import create_long_paper_outline_agent -from dataflow_agent.agentroles.paper2any_agents.content_expander_agent import create_content_expander -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root - -from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_mineru_pdf_extract - -log = get_logger(__name__) - -""" -Workflow: paper2page_content_for_long_paper -Description: 专门用于处理长文档(如书籍、长论文、长篇报告)生成大量 PPT 页面的工作流。 - -Process: -1. Input Routing (_start_ -> _route_input): - - PDF: 解析 PDF 获取全文 markdown (parse_pdf_pages_long) - - TEXT: 直接接收文本输入 (prepare_text_input) - - TOPIC: 根据主题生成长文 (generate_long_content_from_topic) - -2. 
Content Expansion & Consolidation: - - 对于 TEXT/TOPIC 输入,如果内容不足,会进行迭代扩写 (expand_text_iteratively / generate_long_content_from_topic)。 - - 所有来源的内容最终汇总到 state.long_text (consolidate_long_text)。 - - 再次检查总长度,如果不足目标页数所需字符数,进行补充扩写 (ensure_sufficient_content)。 - * 动态字符数计算:英文 ~3000 chars/page, 中文 ~800 chars/page。 - -3. Outline Generation (outline_for_long_text): - - 根据 state.request.page_count (默认为 60) 和总文本长度,计算分批方案。 - - 将长文本切分为多个 batch。 - - 对每个 batch 调用 long_paper_outline_agent 生成对应页面的 outline (generate_outline_for_batch)。 - - 汇总所有批次的页面内容,并进行首尾衔接处理。 - -4. Output: - - 生成的页面列表存储在 state.pagecontent。 -""" - -# ============================================================ -# 辅助函数 -# ============================================================ - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 统一本次 workflow 的根输出目录 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "paper2page_content_long" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _abs_path(p: str) -> str: - if not p: - return "" - try: - return str(Path(p).expanduser().resolve()) - except Exception: - return p - - -def _is_english_text(text: str | Any) -> bool: - """简单判断文本是否主要为英文(ASCII占比 > 80%)""" - if not text: - # 默认非英文(中文)以保持较低的字符阈值,避免误判导致过度扩写 - return False - - if not isinstance(text, str): - try: - text = str(text) - except Exception: - return False - - # 统计前 5000 个字符即可 - sample = text[:5000] - ascii_count = sum(1 for c in sample if ord(c) < 128) - return (ascii_count / len(sample)) > 0.8 - - -def _calculate_target_chars(target_pages: int, text: str = "") -> int: - """ - 根据页数和语言类型计算目标字符数 - 英文:约 3000 chars/page - 中文:约 800 chars/page - """ - is_en = _is_english_text(text) - chars_per_page = 3000 if is_en else 800 - target = target_pages * chars_per_page - # log.info(f"[long_paper] 目标计算: {target_pages}页, 
英文={is_en}, 阈值={target} chars") - return target - - -def split_text_by_chars(text: str, chunk_size: int = 30000) -> List[str]: - """ - 按字符数切分文本,尽量在段落边界切分 - """ - if len(text) <= chunk_size: - return [text] - - chunks = [] - current_pos = 0 - - while current_pos < len(text): - end_pos = min(current_pos + chunk_size, len(text)) - - # 向后查找段落边界(双换行符),但不超过500字符 - if end_pos < len(text): - boundary = text.rfind('\n\n', current_pos, end_pos + 500) - if boundary > current_pos: - end_pos = boundary - - chunks.append(text[current_pos:end_pos]) - current_pos = end_pos - - return chunks - - -def calculate_batches( - total_chars: int, - target_pages: int, - pages_per_batch: int = 10 -) -> List[Tuple[int, int, int, bool, bool]]: - """ - 计算分批方案 - - Args: - total_chars: 总字符数 - target_pages: 目标总页数 - pages_per_batch: 每批次目标页数 - - Returns: - [(start_char, end_char, batch_idx, is_first, is_last), ...] - """ - num_batches = max(1, (target_pages + pages_per_batch - 1) // pages_per_batch) - chars_per_batch = total_chars // num_batches - - batches = [] - for i in range(num_batches): - start_char = i * chars_per_batch - end_char = min((i + 1) * chars_per_batch, total_chars) - is_first = (i == 0) - is_last = (i == num_batches - 1) - batches.append((start_char, end_char, i, is_first, is_last)) - - return batches - - -# ============================================================ -# Workflow 工厂函数 -# ============================================================ - -@register("paper2page_content_for_long_paper") -def create_paper2page_content_graph() -> GenericGraphBuilder: - """ - 长文本 Paper2PageContent Workflow - 专门处理长文本(50页+)的 PDF/TEXT/TOPIC 输入 - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - # ---------------------------------------------------------------------- - # PRE-TOOLS - # ---------------------------------------------------------------------- - - @builder.pre_tool("current_chunk", "long_paper_outline_agent") - def _get_current_chunk(state: 
Paper2FigureState): - """提供当前批次的文本内容""" - return getattr(state, "current_chunk", "") - - @builder.pre_tool("batch_info", "long_paper_outline_agent") - def _get_batch_info(state: Paper2FigureState): - """提供批次信息,用于 prompt 生成""" - idx = getattr(state, "chunk_index", 0) - total = getattr(state, "total_chunks", 1) - pages = getattr(state, "pages_to_generate", 10) - return { - "batch_index": idx + 1, - "total_batches": total, - "pages_to_generate": pages, - "is_first": idx == 0, - "is_last": idx == total - 1, - } - @builder.pre_tool("generation_round", "topic_writer") - def _get_generation_round(state: Paper2FigureState): - """提供 topic 生成轮次信息""" - return getattr(state, "generation_round", 0) - - # ---------------------------------------------------------------------- - # Outline Refine Tools (Added for consistency with standard workflow) - # ---------------------------------------------------------------------- - @builder.pre_tool("outline_feedback", "outline_refine_agent") - def _get_outline_feedback(state: Paper2FigureState): - return state.outline_feedback or "" - - @builder.pre_tool("minueru_output", "outline_refine_agent") - def _get_mineru_markdown_for_refine(state: Paper2FigureState): - return state.minueru_output or "" - - @builder.pre_tool("text_content", "outline_refine_agent") - def _get_text_content_for_refine(state: Paper2FigureState): - return state.text_content or "" - - @builder.pre_tool("pagecontent", "outline_refine_agent") - def _get_pagecontent_for_refine(state: Paper2FigureState): - return json.dumps(state.pagecontent or [], ensure_ascii=False) - - @builder.pre_tool("pagecontent_raw", "outline_refine_agent") - def _get_pagecontent_raw_for_refine(state: Paper2FigureState): - return state.pagecontent or [] - - # ============================================================== - # NODES - # ============================================================== - - def _start_(state: Paper2FigureState) -> Paper2FigureState: - """初始化 state""" - 
_ensure_result_path(state) - - # 初始化字段 - state.minueru_output = state.minueru_output or "" - state.text_content = state.text_content or "" - state.pagecontent = state.pagecontent or [] - state.long_text = getattr(state, "long_text", "") or "" - - # 设置默认目标页数 - # 1. 优先从 request.page_count 获取 - if state.request and state.request.page_count: - state.target_pages = state.request.page_count - # 2. 否则查看 state 中是否有 target_pages - elif not hasattr(state, "target_pages") or not state.target_pages: - state.target_pages = 60 # 默认 60 页 - - log.info(f"[long_paper] 目标页数: {state.target_pages}") - return state - - async def parse_pdf_pages_long(state: Paper2FigureState) -> Paper2FigureState: - """ - PDF 长文解析:读取完整 markdown,不做字符限制 - """ - paper_pdf_path = Path(_abs_path(state.paper_file)) - if not paper_pdf_path.exists(): - log.error(f"[long_paper] PDF 文件不存在: {paper_pdf_path}") - state.long_text = "" - return state - - result_root = Path(_ensure_result_path(state)) - result_root.mkdir(parents=True, exist_ok=True) - - pdf_stem = paper_pdf_path.stem - paper_dir = result_root / pdf_stem - auto_dir = paper_dir / "auto" - - # 触发 MinerU 解析 - if not auto_dir.exists(): - try: - log.info(f"[long_paper] 开始 MinerU 解析: {paper_pdf_path}") - run_mineru_pdf_extract(str(paper_pdf_path), str(result_root), "modelscope") - except Exception as e: - log.error(f"[long_paper] MinerU 解析失败: {e}") - state.long_text = "" - return state - - auto_dir = (result_root / pdf_stem / "auto").resolve() - markdown_path = auto_dir / f"{pdf_stem}.md" - - if not markdown_path.exists(): - log.error(f"[long_paper] Markdown 文件不存在: {markdown_path}") - state.long_text = "" - return state - - try: - md = markdown_path.read_text(encoding="utf-8") - log.info(f"[long_paper] 读取完整 markdown: {len(md)} 字符") - except Exception as e: - log.error(f"[long_paper] 读取 markdown 失败: {e}") - md = "" - - # 不做裁剪,保留完整内容 - state.long_text = md - state.mineru_root = str(auto_dir) - - return state - - async def prepare_text_input(state: 
Paper2FigureState) -> Paper2FigureState: - """ - TEXT 输入:准备文本内容 - """ - log.info(f"[long_paper] TEXT 输入长度: {len(state.text_content)} 字符") - return state - - async def expand_text_iteratively(state: Paper2FigureState) -> Paper2FigureState: - """ - TEXT 循环扩写:扩写到足够长度 - """ - target_pages = getattr(state, "target_pages", 60) - current_text = state.text_content or "" - - # 动态计算目标 - target_chars = _calculate_target_chars(target_pages, current_text) - - log.info(f"[long_paper] 开始扩写,当前: {len(current_text)} 字符,目标: {target_chars} 字符 ({target_pages}页)") - - if len(current_text) >= target_chars: - log.info(f"[long_paper] 初始长度已满足要求") - return state - - max_rounds = state.max_rounds - - agent = create_simple_agent( - name = "content_expander", - temperature=0.7, - parser_type="text", - ) - - for round_num in range(max_rounds): - state.expansion_round = round_num - state.text_content = current_text - - state = await agent.execute(state=state) - - # 增加类型检查,防止 agent 返回 dict 导致后续切片报错 - # 用户要求:直接把字典当字符串 - current_text = str(state.text_content) if state.text_content else "" - - # 重新计算目标(以防语言变化) - target_chars = _calculate_target_chars(target_pages, current_text) - - log.info(f"[long_paper] 扩写轮次 {round_num + 1}/{max_rounds}: {len(current_text)} / {target_chars} 字符") - - if len(current_text) >= target_chars: - log.info(f"[long_paper] 扩写完成,达到目标长度") - break - - state.text_content = current_text - return state - - async def generate_long_content_from_topic(state: Paper2FigureState) -> Paper2FigureState: - """ - TOPIC 多轮生成长文 - """ - target_pages = getattr(state, "target_pages", 60) - max_rounds = state.max_rounds - - current_text = state.text_content or "" - target_chars = target_pages * 800 - - log.info(f"[long_paper] 从 TOPIC 生成长文,当前: {len(current_text)} 字符") - agent = create_simple_agent( - name="topic_writer", - parser_type="text", - ) - for round_num in range(max_rounds): - state.generation_round = round_num - state.text_content = current_text - - state = await 
agent.execute(state=state) - - current_text = str(state.text_content) if state.text_content else "" - - # 动态更新目标 - target_chars = _calculate_target_chars(target_pages, current_text) - log.info(f"[long_paper] 生成轮次 {round_num + 1}/{max_rounds}: {len(current_text)} / {target_chars} 字符") - if len(current_text) >= target_chars: - log.info(f"[long_paper] 生成完成,达到目标长度") - break - state.text_content = current_text - return state - - async def outline_refine_agent(state: Paper2FigureState) -> Paper2FigureState: - """ - outline_refine_agent: refine existing outline based on user feedback. - """ - agent = create_react_agent( - name="outline_refine_agent", - parser_type="json", - max_retries=5 - ) - state = await agent.execute(state=state) - return state - - async def consolidate_long_text(state: Paper2FigureState) -> Paper2FigureState: - """ - 统一整合各来源的长文本到 state.long_text - """ - if state.long_text: - # PDF 路径已经有 long_text - log.info(f"[long_paper] 使用 PDF markdown: {len(state.long_text)} 字符") - elif state.text_content: - # TEXT/TOPIC 路径使用 text_content - state.long_text = state.text_content - log.info(f"[long_paper] 使用 text_content: {len(state.long_text)} 字符") - else: - state.long_text = "" - log.warning("[long_paper] 没有可用的长文本内容") - - return state - - async def ensure_sufficient_content(state: Paper2FigureState) -> Paper2FigureState: - """ - 确保内容足够长,不够则扩写 - """ - target_pages = getattr(state, "target_pages", 60) - long_text = state.long_text or "" - - # 动态计算目标 - target_chars = _calculate_target_chars(target_pages, long_text) - - if len(long_text) >= target_chars: - log.info(f"[long_paper] 内容充足: {len(long_text)} >= {target_chars} 字符") - return state - - log.info(f"[long_paper] 内容不足({len(long_text)} < {target_chars} chars),开始补充扩写") - - agent = create_content_expander( - temperature=0.7, - parser_type="text", - ) - - max_rounds = state.max_rounds - current_text = long_text - - for round_num in range(max_rounds): - state.expansion_round = round_num - state.text_content = 
current_text - - state = await agent.execute(state=state) - - # 增加类型检查 - # 用户要求:直接把字典当字符串 - current_text = str(state.text_content) if state.text_content else "" - - # 重新计算目标 - target_chars = _calculate_target_chars(target_pages, current_text) - - log.info(f"[long_paper] 补充扩写轮次 {round_num + 1}/{max_rounds}: {len(current_text)} / {target_chars} 字符") - - if len(current_text) >= target_chars: - break - - state.long_text = current_text - log.info(f"[long_paper] 最终扩写后长度: {len(state.long_text)} 字符") - return state - - async def generate_outline_for_batch( - state: Paper2FigureState, - chunk_text: str, - batch_idx: int, - total_batches: int, - pages_to_generate: int = 12, - ) -> List[Dict[str, Any]]: - """ - 为单个批次生成 outline - """ - # 深拷贝 state 以防止并发修改冲突 - state = copy.deepcopy(state) - - log.critical(f"[chunk_text: ] {chunk_text[:200]}") - - # 临时设置当前批次信息 - state.current_chunk = chunk_text - state.chunk_index = batch_idx - state.total_chunks = total_batches - state.pages_to_generate = pages_to_generate - - # 显式设置首尾状态,供 Agent 动态选择 Prompt - state.is_first = (batch_idx == 0) - state.is_last = (batch_idx == total_batches - 1) - - # 调用 long_paper_outline_agent - agent = create_react_agent( - name = "long_paper_outline_agent", - temperature=0.1, - max_retries=5, - parser_type="json", - ) - - result_state = await agent.execute(state=state) - - # 提取生成的页面 - pages = result_state.pagecontent or [] - if not isinstance(pages, list): - pages = [pages] - - log.info(f"[long_paper] 批次 {batch_idx + 1}/{total_batches} 生成了 {len(pages)} 页") - return pages - - async def outline_for_long_text(state: Paper2FigureState) -> Paper2FigureState: - """ - 对长文本按目标页数分批生成 outline(并行处理) - """ - import asyncio - - long_text = state.long_text or "" - target_pages = getattr(state, "target_pages", 60) - pages_per_batch = state.pages_per_batch # 每批次目标页数 - pages_to_generate = state.pages_to_generate # 每批次让 agent 生成的页数(含首尾) - - if not long_text: - log.error("[long_paper] 没有长文本内容,无法生成 outline") - state.pagecontent = 
[] - return state - - # 1. 确保内容充足 - target_chars = _calculate_target_chars(target_pages, long_text) - if len(long_text) < target_chars: - log.info(f"[long_paper] 内容不足({len(long_text)} < {target_chars}),触发扩写") - state = await ensure_sufficient_content(state) - long_text = state.long_text - - # 2. 计算分批方案 - batches = calculate_batches(len(long_text), target_pages, pages_per_batch) - log.info(f"[long_paper] 分 {len(batches)} 批次,目标 {target_pages} 页,将并行处理") - - # 3. 并行处理所有批次 - tasks = [] - batch_info = [] # 保存批次信息用于后续处理 - - for start_char, end_char, batch_idx, is_first, is_last in batches: - chunk_text = long_text[start_char:end_char] - - log.info(f"[long_paper] 准备批次 {batch_idx + 1}/{len(batches)}: " - f"字符 {start_char}-{end_char} ({len(chunk_text)} chars)") - - # 创建异步任务 - task = generate_outline_for_batch( - state=state, - chunk_text=chunk_text, - batch_idx=batch_idx, - total_batches=len(batches), - pages_to_generate=pages_to_generate, - ) - tasks.append(task) - batch_info.append((batch_idx, is_first, is_last)) - - # 4. 并行执行所有任务 - log.info(f"[long_paper] 开始并行执行 {len(tasks)} 个批次...") - results = await asyncio.gather(*tasks) - log.info(f"[long_paper] 并行执行完成,收到 {len(results)} 个结果") - - # 5. 按顺序处理结果 - all_pages = [] - for idx, (chunk_pages, (batch_idx, is_first, is_last)) in enumerate(zip(results, batch_info)): - # 不再进行裁剪,直接保留所有生成的页面 - selected = chunk_pages - log.info(f"[long_paper] 批次 {batch_idx + 1}: 生成 {len(chunk_pages)} 页,全部保留") - all_pages.extend(selected) - - # 6. 
确保总页数符合要求 - if len(all_pages) > target_pages: - log.warning(f"[long_paper] 生成页数超出目标({len(all_pages)} > {target_pages}),截断") - all_pages = all_pages[:target_pages] - elif len(all_pages) < target_pages: - log.warning(f"[long_paper] 生成页数不足: {len(all_pages)}/{target_pages}") - - state.pagecontent = all_pages - log.info(f"[long_paper] 并行处理完成,最终生成 {len(all_pages)} 页 pagecontent") - - return state - - # ============================================================== - # 路由函数 - # ============================================================== - - def _route_input(state: Paper2FigureState) -> str: - """根据输入类型路由到不同节点""" - # 优先检查是否有反馈 - feedback = (state.outline_feedback or "").strip() - if feedback and state.pagecontent: - log.critical("走 OUTLINE 反馈修订路径 (Long Paper)") - return "outline_refine_agent" - - t = getattr(state.request, "input_type", None) or getattr(state, "input_type", None) or "" - t = str(t).upper().strip() - - if t == "PDF": - log.info("[long_paper] 路由: PDF → parse_pdf_pages_long") - return "parse_pdf_pages_long" - elif t == "TEXT": - log.info("[long_paper] 路由: TEXT → prepare_text_input") - return "prepare_text_input" - elif t == "TOPIC": - log.info("[long_paper] 路由: TOPIC → generate_long_content_from_topic") - return "generate_long_content_from_topic" - else: - log.error(f"[long_paper] 无效的 input_type: {t},仅支持 PDF/TEXT/TOPIC") - return "_end_" - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - - nodes = { - "_start_": _start_, - - # PDF 路径 - "parse_pdf_pages_long": parse_pdf_pages_long, - - # TEXT 路径 - "prepare_text_input": prepare_text_input, - "expand_text_iteratively": expand_text_iteratively, - - # TOPIC 路径 - "generate_long_content_from_topic": generate_long_content_from_topic, - - # 统一处理 - "consolidate_long_text": consolidate_long_text, - "outline_for_long_text": outline_for_long_text, - - # 修订 - "outline_refine_agent": outline_refine_agent, - - 
"_end_": lambda state: state, - } - - edges = [ - # Refine → End - ("outline_refine_agent", "_end_"), - - # PDF → 统一整合 - ("parse_pdf_pages_long", "consolidate_long_text"), - - # TEXT → 扩写 → 统一整合 - ("prepare_text_input", "expand_text_iteratively"), - ("expand_text_iteratively", "consolidate_long_text"), - - # TOPIC → 生成 → 统一整合 - ("generate_long_content_from_topic", "consolidate_long_text"), - - # 统一整合 → 分批 outline → 结束 - ("consolidate_long_text", "outline_for_long_text"), - ("outline_for_long_text", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", _route_input) - - return builder diff --git a/dataflow_agent/workflow/wf_paper2ppt_parallel.py b/dataflow_agent/workflow/wf_paper2ppt_parallel.py deleted file mode 100644 index afa951a..0000000 --- a/dataflow_agent/workflow/wf_paper2ppt_parallel.py +++ /dev/null @@ -1,628 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import os -import time -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_react_agent -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.toolkits.multimodaltool.ppt_tool import convert_images_dir_to_pdf_and_ppt, convert_images_dir_to_pdf_and_ppt_api - -log = get_logger(__name__) - - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 统一 paper2ppt workflow 的根输出目录: - - 若 state.result_path 已存在(通常由调用方传入),直接使用; - - 否则:使用 get_project_root()/outputs/paper2ppt/ 初始化,并写回 state.result_path。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = 
int(time.time()) - base_dir = (root / "outputs" / "paper2ppt" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _abs_path(p: str) -> str: - if not p: - return "" - try: - return str(Path(p).expanduser().resolve()) - except Exception: - return p - - -def _is_table_asset(asset_ref: Optional[str]) -> bool: - """ - 你给的约定:asset 是 Table 时,通过 asset_ref: "Table 2" 这种字符串标记。 - """ - if not asset_ref: - return False - s = str(asset_ref).strip().lower() - return s.startswith("table") - - -def _serialize_prompt_dict(d: Dict[str, Any]) -> str: - """ - 把 dict 安全序列化为 prompt 文本(中文不转义)。 - """ - try: - return json.dumps(d, ensure_ascii=False, indent=2) - except Exception: - # 兜底:不要因为序列化失败而中断 - return str(d) - - -def _normalize_single_asset_ref(asset_ref: str) -> str: - """ - 规范化 asset_ref,仅保留第一张图的路径/文件名。 - - 当前版本不支持多图编辑: - - 如果 asset_ref 中包含逗号等分隔符,如 "a.jpg,b.jpg", - 只取第一段 "a.jpg"。 - - TODO: 后续可以扩展多图 asset_ref 支持。 - """ - if not asset_ref: - return "" - s = str(asset_ref).strip() - if not s: - return "" - - # 简单按逗号切分,保留第一个 - parts = [p.strip() for p in s.split(",") if p.strip()] - if not parts: - return "" - - if len(parts) > 1: - log.warning( - "[paper2ppt] asset_ref 包含多张图片,仅使用第一张。" - f" raw={asset_ref!r}, first={parts[0]!r} # TODO: 支持多图 asset_ref" - ) - - return parts[0] - - -async def _make_prompt_for_structured_page(item: Dict[str, Any], style: str, state: Paper2FigureState) -> Tuple[str, Optional[str], bool]: - """ - 根据结构化 page item 生成: - - prompt - - image_path (如果是编辑模式) - - use_edit - - 规则: - 1) asset 为空:text2img,用 “json(去asset)” + “根据上述内容生成{style}风格的PPT” - 2) asset 是图片路径:img2img/edit,用 “json(去asset)” + “把这个图作为PPT的一部分...” - 3) asset 是 Table(asset_ref="Table 2"):先提取 table png(这里先占位),再走 edit - """ - asset_ref = item.get("asset_ref") or item.get("asset") or item.get("assetRef") or "" - asset_ref = str(asset_ref).strip() if asset_ref is not None else "" - # TODO 当前版本仅支持单图 
asset_ref;若包含多图,仅保留第一张。 - asset_ref = _normalize_single_asset_ref(asset_ref) - - prompt_dict = dict(item) - for k in ["asset_ref", "asset", "assetRef", "asset_type", "type"]: - if k in prompt_dict: - prompt_dict.pop(k, None) - - base = _serialize_prompt_dict(prompt_dict) - - if not asset_ref: - prompt = f"{base}\n\n根据上述内容。生成{style}风格的 PPT 图像, \n 使用语言:{state.request.language}" - return prompt, None, False - - # table 走占位提取 - if _is_table_asset(asset_ref): - # 优先使用 item 自己带的表格图(如果调用方已经生成过) - table_img_path = item.get("table_img_path") or item.get("table_png_path") or "" - table_img_path = str(table_img_path).strip() - - # 若没有,则调用 table_extractor agent:生成 html->png,并写入 state.table_img_path - if not table_img_path: - state.asset_ref = asset_ref - agent = create_react_agent( - name="table_extractor", - temperature=0.1, - max_retries=6, - parser_type="json", - ) - state = await agent.execute(state=state) - - table_img_path = str(getattr(state, "table_img_path", "") or "").strip() - log.critical(f'[table_img_path 表格图像路径]: {table_img_path}') - - if not table_img_path: - raise ValueError(f"[paper2ppt] 表格提取失败,未得到 table_img_path。asset_ref={asset_ref}") - - image_path = _resolve_asset_path(table_img_path, state) - # 如果表格图像不存在,则退化为 text2img:不走编辑,返回 use_edit=False - if not image_path or not os.path.exists(image_path): - log.error(f"[paper2ppt] 表格图像文件不存在: {image_path!r} (asset_ref={asset_ref})") - prompt = f"{base}\n\n根据上述内容生成{style}风格的 PPT 图像, \n 使用语言:{state.request.language}" - return prompt, None, False - - prompt = f"{base}\n\n根据上述内容绘制ppt,把这个图作为PPT的一部分。生成{style}风格的PPT. \n 使用语言:{state.request.language} !!!" 
- return prompt, image_path, True - - # 默认:当作图片路径,走编辑 - image_path = _resolve_asset_path(asset_ref, state) - # 如果图片不存在,则退化为 text2img:不走编辑,返回 use_edit=False - if not image_path or not os.path.exists(image_path): - log.error(f"[paper2ppt] 图片文件不存在: {image_path!r} (asset_ref={asset_ref})") - prompt = f"{base}\n\n根据上述内容生成{style}风格的 PPT 图像, \n 使用语言:{state.request.language}" - return prompt, None, False - - prompt = f"{base}\n\n根据上述内容绘制ppt,把这个图作为PPT的一部分。生成{style}风格的PPT. \n 使用语言:{state.request.language} !!!" - return prompt, image_path, True - - -def _resolve_asset_path(asset_ref: str, state: Paper2FigureState) -> str: - """ - 根据 state 解析 asset 引用为绝对路径。 - - 规则: - - 为空直接返回 ""; - - 绝对路径或以 ~ 开头:直接通过 _abs_path 规范化; - - 相对路径: - * 优先挂在 state.mineru_root(MinerU 输出根目录)下; - * 否则挂在 state.result_path 下; - * 再否则退化为当前工作目录下的相对路径解析(_abs_path)。 - """ - if not asset_ref: - return "" - s = str(asset_ref).strip() - if not s: - return "" - - p = Path(s) - - # 已经是绝对路径,或者显式使用家目录 - if p.is_absolute() or s.startswith("~"): - return _abs_path(s) - - base_dir = getattr(state, "mineru_root", None) or getattr(state, "result_path", None) - log.critical(f'[base_dir _resolve_asset_path]: {base_dir}') - - if base_dir: - try: - return str((Path(base_dir) / p).resolve()) - except Exception: - return _abs_path(s) - - return _abs_path(s) - - -def _extract_image_path_from_pagecontent_item(item: Any) -> Optional[str]: - """ - 支持 pagecontent 直接是图片路径的几种形态: - - "/abs/xxx.png" - - {"ppt_img_path": "/abs/xxx.png"} - - {"img_path": "/abs/xxx.png"} - - {"path": "/abs/xxx.png"} - """ - if not item: - return None - if isinstance(item, str): - return item.strip() - if isinstance(item, dict): - for k in ["ppt_img_path", "img_path", "path", "image_path"]: - v = item.get(k) - if v: - return str(v).strip() - return None - - -@register("paper2ppt_parallel") -def create_paper2ppt_parallel_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf paper2ppt_parallel - - 功能: - - 并行版本:并发调用 AI 接口生成/编辑所有 
PPT 页面 - - 若 state.gen_down == False:批量生成/编辑每页 PPT 图,保存到统一目录 - - 若 state.gen_down == True:按 0-based edit_page_num 对已有页面图做二次编辑(edit_page_prompt) - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - def _start_(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - state.pagecontent = state.pagecontent or [] - state.generated_pages = state.generated_pages or [] - # 兼容:有些调用方把 style 放 state.style,而不是 request.style - if not getattr(state.request, "style", None) and getattr(state, "style", None): - state.request.style = getattr(state, "style") - return state - - def _route(state: Paper2FigureState) -> str: - # 如果是 all_edited_down,说明用户只想打包下载,不需要生成或编辑,直接去导出 - if getattr(state.request, "all_edited_down", False): - return "export_ppt_assets" - - # gen_down == False: 第一次批量生成 - if not getattr(state, "gen_down", False): - return "generate_pages" - # gen_down == True: 进入按页编辑 - return "edit_single_page" - - async def generate_pages(state: Paper2FigureState) -> Paper2FigureState: - """ - 批量生成/编辑页面图(并行版本): - - pagecontent 由上游 kb_page_content 的 LLM 大纲生成,无人工确认。 - - 按 asset 规则决定 text2img / img2img;或直接为图片路径列表时逐页编辑。 - 并发调用 asyncio.gather 处理所有页面。 - """ - import asyncio - page_items = state.pagecontent or [] - log.info("[paper2ppt] 使用 pagecontent 生图,共 %s 页(大纲由 LLM 生成,无人工确认)", len(page_items)) - - async def _call_image_api_with_retry(coro_factory, retries: int = 3, delay: float = 1.0) -> bool: - """ - 对图像生成/编辑进行最多 retries 次重试。 - - 成功:返回 True - - 多次失败:返回 False(由上层决定如何处理当前页) - """ - last_err: Optional[Exception] = None - for attempt in range(1, retries + 1): - try: - await coro_factory() - return True - except Exception as e: # noqa: BLE001 - last_err = e - log.error(f"[paper2ppt] image gen failed attempt {attempt}/{retries}: {e}") - if attempt < retries: - try: - await asyncio.sleep(delay) - except Exception: - # sleep 失败不影响后续重试 - pass - log.error(f"[paper2ppt] image gen failed after {retries} attempts, skip this page. 
last_err={last_err}") - return False - - result_root = Path(_ensure_result_path(state)) - img_dir = result_root / "ppt_pages" - img_dir.mkdir(parents=True, exist_ok=True) - - style = getattr(state.request, "style", None) or "kartoon" - aspect_ratio = getattr(state, "aspect_ratio", None) or "16:9" - - # 清空旧数据(避免重复执行堆积) - state.generated_pages = [] - - # 定义单个页面处理任务 - async def _process_single_page(idx: int, item: Any) -> Dict[str, Any]: - """ - 处理单个页面:返回生成的 result item (dict)。 - 如果失败,result item 中的 generated_img_path 为 None。 - """ - save_path = str((img_dir / f"page_{idx:03d}.png").resolve()) - - # Case B: pagecontent 本身就是图片路径 - direct_img_path = _extract_image_path_from_pagecontent_item(item) - is_direct_image_list = bool(direct_img_path) and ( - isinstance(item, str) - or (isinstance(item, dict) and set(item.keys()).intersection({"ppt_img_path", "img_path", "path", "image_path"})) - ) - - if is_direct_image_list and (not isinstance(item, dict) or ("title" not in item and "layout_description" not in item)): - # 规则 2:只做风格化编辑 - image_path = _abs_path(direct_img_path) - # 强化提示词,确保模型进行重绘而不是原图输出 - prompt = ( - f"Please beautify and re-design this PowerPoint slide image. " - f"Keep all the original text and structure, but completely transform it into a professional, " - f"visually stunning presentation slide in {style} style. " - f"Make sure the colors, layout, and background are improved significantly." 
- ) - log.info(f"[paper2ppt] page={idx} direct image edit: image={image_path}, save={save_path}") - - log.critical(f'[强化提示词,确保模型进行重绘而不是原图输出]: {prompt}') - - ok = await _call_image_api_with_retry( - lambda: generate_or_edit_and_save_image_async( - prompt=prompt, - save_path=save_path, - aspect_ratio=aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - image_path=image_path, - use_edit=True, - ) - ) - if not ok: - # 记录失败信息 - return { - "source_img_path": image_path, - "generated_img_path": None, - "page_idx": idx, - "mode": "edit_direct_image_failed", - "style": style, - } - - return { - "source_img_path": image_path, - "generated_img_path": save_path, - "page_idx": idx, - "mode": "edit_direct_image", - "style": style, - } - - # Case A: 结构化页面 - if not isinstance(item, dict): - log.warning(f"[paper2ppt] page={idx} 非 dict 且非 image path,跳过。item={item}") - return { - "page_idx": idx, - "mode": "invalid_item_skipped", - "generated_img_path": None, - } - - try: - # 注意:_make_prompt_for_structured_page 可能是 async 的,因为它可能调用 table_extractor agent - prompt, image_path, use_edit = await _make_prompt_for_structured_page(item, style=style, state=state) - except Exception as e: # noqa: BLE001 - log.error(f"[paper2ppt] page={idx} prompt 构造失败: {e}") - failed_item = dict(item) - failed_item.update({ - "generated_img_path": None, - "page_idx": idx, - "mode": "prompt_build_failed", - "style": style, - "error": str(e), - }) - return failed_item - - title = (item.get("title") or "")[:40] - log.info( - "[paper2ppt] page=%s/%s title=%s use_edit=%s save=%s", - idx + 1, len(state.pagecontent or []), title, use_edit, save_path, - ) - - ok = await _call_image_api_with_retry( - lambda: generate_or_edit_and_save_image_async( - prompt=prompt, - save_path=save_path, - aspect_ratio=aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , 
- model=state.request.gen_fig_model, - image_path=image_path, - use_edit=use_edit, - ) - ) - if not ok: - failed_item = dict(item) - failed_item.update({ - "generated_img_path": None, - "page_idx": idx, - "mode": "generate_failed" if not use_edit else "edit_failed", - "style": style, - }) - return failed_item - - # 成功 - out_item = dict(item) - out_item.update({ - "generated_img_path": save_path, - "page_idx": idx, - "mode": "edit" if use_edit else "generate", - "style": style, - }) - return out_item - - # ----------------------------------------------------------- - # 并发执行逻辑 - # ----------------------------------------------------------- - tasks = [] - for idx, item in enumerate(page_items): - tasks.append(_process_single_page(idx, item)) - log.info("[paper2ppt] 并发生图开始,共 %s 页", len(tasks)) - start_time = time.time() - - # 使用 gather 并发执行所有任务 - results = await asyncio.gather(*tasks, return_exceptions=True) - - cost_time = time.time() - start_time - log.info("[paper2ppt] 并发生图完成,耗时 %.2fs", cost_time) - - # 整理结果 - new_pagecontent: List[Dict[str, Any]] = [] - state.generated_pages = [] - - for i, res in enumerate(results): - if isinstance(res, Exception): - # 理论上 _process_single_page 内部捕获了大部分异常,这里是防漏 - log.error(f"[paper2ppt_parallel] page {i} unhandled exception: {res}") - # 构造一个失败项 - failed_item = dict(page_items[i]) if isinstance(page_items[i], dict) else {"raw_item": str(page_items[i])} - failed_item.update({ - "generated_img_path": None, - "page_idx": i, - "mode": "unhandled_exception", - "error": str(res) - }) - new_pagecontent.append(failed_item) - else: - # res 是 dict - res_dict = res # type: ignore - new_pagecontent.append(res_dict) - gen_path = res_dict.get("generated_img_path") - - if gen_path: - state.generated_pages.append(gen_path) - else: - # 占位,防止索引错位 - state.generated_pages.append("") - - state.pagecontent = new_pagecontent - return state - - async def edit_single_page(state: Paper2FigureState) -> Paper2FigureState: - """ - gen_down == True 时的路径: - 通过 
edit_page_num(0-based) + edit_page_prompt 对已经生成好的某一页做二次编辑。 - - 当前策略(B1): - - 不再生成 *_edit_*.png 新文件; - - 直接覆盖原来的 page_{idx:03d}.png,保证导出时每页只有一张图。 - """ - idx = int(getattr(state, "edit_page_num", -1)) - prompt = (getattr(state, "edit_page_prompt", "") or "").strip() - if idx < 0: - raise ValueError("[paper2ppt] edit_page_num 必须是 0-based 且 >=0") - - # 取出原图路径:优先 generated_pages,其次 pagecontent[i].ppt_img_path - old_path: Optional[str] = None - - # Debug log - log.info(f"[paper2ppt] edit_single_page: idx={idx}") - # log.info(f"[paper2ppt] generated_pages={getattr(state, 'generated_pages', None)}") - - if getattr(state, "generated_pages", None) and idx < len(state.generated_pages): - old_path = state.generated_pages[idx] - log.info(f"[paper2ppt] got old_path from generated_pages: {old_path}") - - if not old_path and idx < len(state.pagecontent or []): - it = state.pagecontent[idx] - if isinstance(it, dict): - old_path = it.get("generated_img_path") or it.get("ppt_img_path") or it.get("img_path") - log.info(f"[paper2ppt] got old_path from pagecontent: {old_path}") - - if not old_path: - raise ValueError(f"[paper2ppt] 找不到要编辑的页图路径: idx={idx}") - - old_path = _abs_path(old_path) - - result_root = Path(_ensure_result_path(state)) - img_dir = result_root / "ppt_pages" - img_dir.mkdir(parents=True, exist_ok=True) - - # B1 策略:编辑时直接覆盖原始 page_{idx:03d}.png,避免 *_edit_*.png 累积 - save_path = str((img_dir / f"page_{idx:03d}.png").resolve()) - aspect_ratio = getattr(state, "aspect_ratio", None) or "16:9" - style = getattr(state.request, "style", None) or "kartoon" - - # 强化提示词,确保模型进行重绘 - if prompt: - # 用户提供了具体修改意见 - full_prompt = ( - f"Beautify this PowerPoint slide based on this instruction: '{prompt}'. " - # f"Transform the existing design into a high-end, professional {style} style presentation. " - f"Enhance the visual aesthetics, layout, and background while preserving the core message." 
- ) - else: - # 用户未提供具体修改意见,仅仅请求重新生成/美化 - full_prompt = ( - f"Beautify and re-design this PowerPoint slide. " - f"Transform the existing design into a high-end, professional {style} style presentation. " - f"Enhance the visual aesthetics, layout, and background while preserving the core message." - ) - - log.info(f"[paper2ppt] edit_single_page idx={idx} old={old_path} save={save_path}") - log.critical(f'[full_prompt] {full_prompt}') - - await generate_or_edit_and_save_image_async( - prompt=full_prompt, - save_path=save_path, - aspect_ratio=aspect_ratio, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - image_path=old_path, - use_edit=True, - ) - - # 回写路径 - if getattr(state, "generated_pages", None) and idx < len(state.generated_pages): - state.generated_pages[idx] = save_path - if idx < len(state.pagecontent or []): - it = state.pagecontent[idx] - if isinstance(it, dict): - it["generated_img_path"] = save_path - it["edit_prompt"] = prompt - it["mode"] = "edit_again" - - # 清理编辑请求(可选) - state.edit_page_prompt = "" - state.edit_page_num = -1 - return state - - async def export_ppt_assets(state: Paper2FigureState) -> Paper2FigureState: - """ - 最终导出节点: - - 使用 ppt_tool.convert_images_dir_to_pdf_and_ppt_api(带 API inpainting 支持) - 将 result_path/ppt_pages 中的页面图导出为 PDF 和可编辑 PPTX。 - - 注意: - - gen_down == False(首次生成):始终导出; - - gen_down == True(编辑模式):只有在 request.all_edited_down == True 时才导出, - 否则直接跳过该节点。 - """ - # 若处于编辑模式且未标记全部编辑完成,则跳过导出 - if getattr(state, "gen_down", False): - all_done = getattr(getattr(state, "request", None), "all_edited_down", False) - if not all_done: - log.info("[paper2ppt] export_ppt_assets skipped: gen_down=True & all_edited_down is False") - return state - - result_root = Path(_ensure_result_path(state)) - img_dir = result_root / "ppt_pages" - - if not img_dir.exists(): - raise ValueError(f"[paper2ppt] export_ppt_assets: image dir not found: {img_dir}") - - 
pdf_path = result_root / "paper2ppt.pdf" - pptx_path = result_root / "paper2ppt_editable.pptx" - - log.info( - f"[paper2ppt] export_ppt_assets: images_dir={img_dir}, " - f"pdf={pdf_path}, pptx={pptx_path}" - ) - - # 使用新的 API 版本函数(带 inpainting 支持) - out = await convert_images_dir_to_pdf_and_ppt_api( - input_dir=str(img_dir), - output_pdf_path=str(pdf_path), - output_pptx_path=None, - api_url=state.request.chat_api_url, - api_key=state.request.chat_api_key or os.getenv("DF_API_KEY") , - model=state.request.gen_fig_model, - use_api_inpaint=False, # 启用 API inpainting - ) - - # 可选:把导出结果路径挂到 state 上,方便后续使用 - setattr(state, "ppt_pdf_path", out.get("pdf") or str(pdf_path)) - setattr(state, "ppt_pptx_path", None) - # 标记首次生成已完成,后续可进入编辑模式 - state.gen_down = True - - return state - - nodes = { - "_start_": _start_, - "generate_pages": generate_pages, - "edit_single_page": edit_single_page, - "export_ppt_assets": export_ppt_assets, - "_end_": lambda state: state, - } - - edges = [ - ("generate_pages", "export_ppt_assets"), - ("edit_single_page", "export_ppt_assets"), - ("export_ppt_assets", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", _route) - return builder diff --git a/dataflow_agent/workflow/wf_paper2technical.py b/dataflow_agent/workflow/wf_paper2technical.py deleted file mode 100644 index 41714d8..0000000 --- a/dataflow_agent/workflow/wf_paper2technical.py +++ /dev/null @@ -1,632 +0,0 @@ -""" -paper2technical workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-07 23:36:51 - -1. 在 **TOOLS** 区域定义需要暴露给 Prompt 的前置工具 -2. 在 **NODES** 区域实现异步节点函数 (await-able) -3. 在 **EDGES** 区域声明有向边 -4. 
最后返回 builder.compile() 或 GenericGraphBuilder -""" - -from __future__ import annotations -import json -import time -from pathlib import Path -import re - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import create_simple_agent -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from dataflow_agent.toolkits.multimodaltool.bg_tool import ( - local_tool_for_svg_render, - local_tool_for_raster_to_svg, -) -from dataflow_agent.utils import get_project_root -from dataflow_agent.logger import get_logger -log = get_logger(__name__) - - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 统一本次 workflow 的根输出目录: - - 如果 state.result_path 已存在(通常由调用方传入,形如 时间戳+编码),直接使用; - - 否则:使用 get_project_root() / "outputs" / "paper2tec" / , - 并回写到 state.result_path,确保后续节点共享同一目录,避免数据串台。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(time.time()) - base_dir = (root / "outputs" / "paper2tec" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _extract_svg_from_react_md(md_path: Path) -> str: - """ - 从 React 组件的 .md 文件中提取纯 SVG 代码。 - - 输入的 .md 文件包含 React 组件代码,其中嵌入了 SVG。 - 此函数提取 ... 部分,并将 React 语法转换为纯 SVG。 - """ - if not md_path.exists(): - log.warning(f"模板文件不存在: {md_path}") - return "" - - try: - content = md_path.read_text(encoding="utf-8") - - # 查找 结束标签 - svg_end = content.rfind("") - if svg_end == -1: - log.warning(f"未在文件中找到 标签: {md_path}") - return "" - - svg_end += len("") - svg_code = content[svg_start:svg_end] - - # 清理 React 特有语法: - # 1. 将 className 替换为 class - svg_code = svg_code.replace('className="', 'class="') - - # 2. 
将 JSX 驼峰命名属性转换为 SVG 连字符命名 - # 注意:某些属性在 SVG 中必须保持驼峰命名(如 markerWidth, markerHeight, viewBox 等) - jsx_to_svg_attrs = { - 'strokeWidth': 'stroke-width', - 'strokeDasharray': 'stroke-dasharray', - 'strokeLinecap': 'stroke-linecap', - 'strokeLinejoin': 'stroke-linejoin', - 'strokeOpacity': 'stroke-opacity', - 'fillOpacity': 'fill-opacity', - 'textAnchor': 'text-anchor', - 'fontWeight': 'font-weight', - 'fontSize': 'font-size', - 'fontFamily': 'font-family', - # markerWidth 和 markerHeight 应该保持驼峰命名,不转换 - 'markerEnd': 'marker-end', - 'markerStart': 'marker-start', - 'markerMid': 'marker-mid', - 'clipPath': 'clip-path', - } - for jsx_attr, svg_attr in jsx_to_svg_attrs.items(): - svg_code = svg_code.replace(f'{jsx_attr}=', f'{svg_attr}=') - - # 3. 将 {colors.xxx} 这样的变量引用替换为实际颜色值 - colors_match = re.search(r'const colors = \{([^}]+)\}', content, re.DOTALL) - if colors_match: - colors_def = colors_match.group(1) - # 解析颜色定义 - color_map = {} - for line in colors_def.split('\n'): - match = re.search(r'(\w+):\s*"([^"]+)"', line) - if match: - color_map[match.group(1)] = match.group(2) - - # 替换 {colors.xxx} 为实际颜色值 - for key, value in color_map.items(): - svg_code = svg_code.replace(f'{{colors.{key}}}', value) - - # 4. 移除 React 注释 {/* ... */} - svg_code = re.sub(r'\{/\*.*?\*/\}', '', svg_code, flags=re.DOTALL) - - # 5. 
转义 XML 特殊字符(在文本内容中) - # 注意:只转义 text 元素内的 &,不转义已经是实体引用的部分 - # 使用负向前瞻确保不会重复转义已经转义的内容 - svg_code = re.sub(r'&(?!amp;|lt;|gt;|quot;|apos;|#)', '&', svg_code) - - return svg_code.strip() - - except Exception as e: - log.error(f"提取 SVG 代码失败: {e}") - return "" - - -def _get_template_svg_code(state: Paper2FigureState, use_color: bool = False) -> str: - """ - 根据语言和配色选择合适的 SVG 模板代码。 - - - 中文灰度: dataflow_agent/workflow/resources/SVG_template_ZN_gray.md - - 中文彩色: dataflow_agent/workflow/resources/SVG_template_ZN_color.md - - 英文灰度: dataflow_agent/workflow/resources/SVG_template_EN_gray.md - - 英文彩色: dataflow_agent/workflow/resources/SVG_template_EN_color.md - - Args: - state: 工作流状态 - use_color: 是否使用彩色模板 - - 返回纯 SVG 代码字符串。 - """ - root = get_project_root() - lang = getattr(getattr(state, "request", None), "language", "EN") - - # 模板目录 - template_dir = root / "dataflow_agent" / "workflow" / "resources" - - # 根据语言和配色选择模板文件 - lang_prefix = "ZN" if lang.upper() in ["ZH", "CN", "CHINESE", "中文"] else "EN" - color_suffix = "color" if use_color else "gray" - template_file = template_dir / f"SVG_template_{lang_prefix}_{color_suffix}.md" - - svg_code = _extract_svg_from_react_md(template_file) - - if not svg_code: - log.warning(f"无法从模板文件提取 SVG 代码: {template_file}") - - return svg_code - - -def _get_palette_config(state: Paper2FigureState) -> dict | None: - """ - 根据 request.tech_route_palette 返回色卡配置;未选择则返回 None。 - """ - palette_name = getattr(getattr(state, "request", None), "tech_route_palette", "") or "" - if not palette_name: - return None - - palettes = { - "academic_blue": { - "name": "academic_blue", - "colors": ["#1F6FEB", "#60A5FA", "#A7C7FF", "#0B3D91"], - "level_colors": ["#A7C7FF", "#60A5FA", "#1F6FEB", "#0B3D91"], - "arrow_color": "#0B3D91", - "text_color": "#0B3D91", - }, - "teal_orange": { - "name": "teal_orange", - "colors": ["#0F766E", "#14B8A6", "#F59E0B", "#FB923C"], - "level_colors": ["#14B8A6", "#0F766E", "#F59E0B", "#FB923C"], - "arrow_color": "#0F766E", - "text_color": 
"#0F766E", - }, - "slate_rose": { - "name": "slate_rose", - "colors": ["#334155", "#64748B", "#F43F5E", "#FCA5A5"], - "level_colors": ["#64748B", "#334155", "#FCA5A5", "#F43F5E"], - "arrow_color": "#334155", - "text_color": "#334155", - }, - "indigo_amber": { - "name": "indigo_amber", - "colors": ["#4338CA", "#6366F1", "#F59E0B", "#FCD34D"], - "level_colors": ["#6366F1", "#4338CA", "#FCD34D", "#F59E0B"], - "arrow_color": "#4338CA", - "text_color": "#4338CA", - }, - } - - return palettes.get(palette_name) - - -@register("paper2technical") -def create_paper2technical_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf paper2technical - """ - # 使用 Paper2FigureState,复用其中的 paper_file / paper_idea / fig_desc 等字段, - # 这里不做图像生成和抠图,只负责"技术路线图"的 SVG 生成。 - builder = GenericGraphBuilder( - state_model=Paper2FigureState, - entry_point="_start_", # 入口统一为 _start_,再由路由函数分发 - ) - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - # 1) 提供给 paper_idea_extractor 的 PDF 内容(标题 + 前几页正文) - @builder.pre_tool("paper_content", "paper_idea_extractor") - def _get_paper_content(state: Paper2FigureState): - """ - 前置工具: 读取论文 PDF 的标题和前若干页内容,供 paper_idea_extractor 节点使用。 - - - 作用: 为大模型提供足够的上下文,让其抽取论文中的技术路线/实验流程关键信息。 - - 输出: 一个字符串,包含论文标题 + 前若干页文本。 - """ - import fitz # PyMuPDF - import PyPDF2 - - pdf_path = state.paper_file - if not pdf_path: - log.warning("paper_file 为空,无法读取 PDF 内容") - return "" - - try: - with open(pdf_path, "rb") as f: - reader = PyPDF2.PdfReader(f) - paper_title = reader.metadata.get("/Title", "Unknown Title") - except Exception: - paper_title = "Unknown Title" - - try: - doc = fitz.open(pdf_path) - except Exception as e: - log.error(f"打开 PDF 失败: {e}") - return f"The title of the paper is {paper_title}" - - text_parts: list[str] = [] - # 读取前 10 页内容,通常技术路线、整体框架会在前几页出现 - for page_idx in range(min(10, len(doc))): - page = 
doc.load_page(page_idx) - text_parts.append(page.get_text("text") or "") - - content = "\n".join(text_parts).strip() - final_text = ( - f"The title of the paper is {paper_title}\n\n" - f"Here are the first 10 pages of the paper:\n{content}" - ) - log.info("paper_content 提取完成") - return final_text - - @builder.pre_tool("paper_idea", "technical_route_bw_svg_generator") - def _get_bw_paper_idea(state: Paper2FigureState): - return state.paper_idea or "" - - @builder.pre_tool("template_svg_code", "technical_route_bw_svg_generator") - def _get_template_svg(state: Paper2FigureState): - """ - 前置工具: 提供 SVG 模板代码给黑白技术路线图生成器。 - - - 作用: 优先使用参考图生成的 SVG 代码;如果没有参考图,则根据语言选择合适的 SVG 模板。 - - 输出: 纯 SVG 代码字符串。 - """ - # 优先使用参考图生成的 SVG 代码 - if hasattr(state, "temp_data") and state.temp_data.get("reference_svg_code"): - log.info("[_get_template_svg] 使用参考图生成的 SVG 代码作为模板") - return state.temp_data["reference_svg_code"] - - # 否则使用默认模板 - log.info("[_get_template_svg] 使用默认 SVG 模板") - return _get_template_svg_code(state) - - @builder.pre_tool("validation_feedback", "technical_route_bw_svg_generator") - def _get_bw_feedback(state: Paper2FigureState): - return state.temp_data.get("validation_feedback", "") if hasattr(state, "temp_data") else "" - - @builder.pre_tool("validation_feedback", "technical_route_colorize_svg") - def _get_color_feedback(state: Paper2FigureState): - return state.temp_data.get("validation_feedback", "") if hasattr(state, "temp_data") else "" - - @builder.pre_tool("bw_svg_code", "technical_route_colorize_svg") - def _get_bw_svg_code(state: Paper2FigureState): - return state.figure_tec_svg_bw_content or "" - - @builder.pre_tool("palette_json", "technical_route_colorize_svg") - def _get_palette_json(state: Paper2FigureState): - return state.temp_data.get("palette_json", "") if hasattr(state, "temp_data") else "" - - @builder.pre_tool("color_template_svg", "technical_route_colorize_svg") - def _get_color_template_svg(state: Paper2FigureState): - """ - 前置工具: 提供彩色 SVG 模板代码给彩色化 
agent 作为参考。 - - - 作用: 让 agent 了解彩色模板的配色风格和结构。 - - 输出: 彩色 SVG 模板代码字符串。 - """ - return _get_template_svg_code(state, use_color=True) - - @builder.pre_tool("reference_image_path", "tech_route_reference_analyzer") - def _get_reference_image_path(state: Paper2FigureState): - """ - 前置工具: 提供参考图路径给 VLM 分析器。 - """ - return state.temp_data.get("reference_image_path", "") if hasattr(state, "temp_data") else "" - - # ---------------------------------------------------------------------- - - # ============================================================== - # NODES - # ============================================================== - async def paper_idea_extractor_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点 1: 从 PDF 中抽取论文的核心思想 / 技术路线相关信息 - - - 只在 input_type == "PDF" 时作为入口节点被调用。 - - 基于 pre_tool("paper_content") 提供的标题 + 前若干页内容, - 调用专门的 agent(例如 paper_idea_extractor)生成摘要。 - - 该摘要用于后续技术路线图描述生成。 - - 输入: - state.paper_file : 论文 PDF 路径 - 输出: - state.paper_idea : 论文核心思想 / 技术路线要点摘要 - state.agent_results["paper_idea_extractor"] : agent 原始输出 - """ - agent = create_simple_agent("paper_idea_extractor") - state = await agent.execute(state=state) - return state - - async def reference_image_analyzer_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 节点: 使用 VLM 分析参考图,提取布局、风格、配色等信息 - - - 只在有参考图时被调用 - - 分析结果存入 state.temp_data["reference_understanding"] - """ - from dataflow_agent.agentroles.paper2any_agents.tech_route_reference_analyzer import ( - create_tech_route_reference_analyzer, - ) - - ref_img_path = state.temp_data.get("reference_image_path", "") if hasattr(state, "temp_data") else "" - if not ref_img_path: - log.warning("reference_image_analyzer_node: 无参考图路径,跳过分析") - return state - - log.info(f"[reference_image_analyzer_node] 分析参考图: {ref_img_path}") - - # 使用 VLM 模式分析参考图 - model_name = getattr(getattr(state, "request", None), "tec_vlm_desc_model", "") or "deepseek-v3.2" - agent = create_tech_route_reference_analyzer( - model_name=model_name, - 
temperature=0.0, - parser_type="json", - use_vlm=True, - vlm_config={"input_image": ref_img_path}, - ) - state = await agent.execute(state=state) - log.info(f"[reference_image_analyzer_node] 分析完成") - return state - - def _svg_has_cjk(text: str) -> bool: - """简单判断 SVG 中是否包含中文字符,用于日志和调试。""" - return bool(re.search(r"[\u4e00-\u9fff]", text)) - - - def _post_process_svg(svg_code: str) -> str: - """ - SVG 后处理: - 1. 确保 标签包含 xmlns="http://www.w3.org/2000/svg" 命名空间(修复浏览器显示为XML代码的问题)。 - 2. 如果包含中文字符,注入中文友好字体。 - """ - if not svg_code: - return svg_code - - # 1. 注入命名空间 - if 'xmlns="http://www.w3.org/2000/svg"' not in svg_code: - log.info("[_post_process_svg] 注入 xmlns 命名空间") - # 查找 - # 简单替换第一个 ") - if idx != -1: - style_block = f""" - - """ - svg_code = svg_code[:idx + 1] + style_block + svg_code[idx + 1:] - - return svg_code - - async def technical_route_bw_svg_generator_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 黑白技术路线图生成(使用 ReAct 模式带验证器) - """ - from dataflow_agent.agentroles import create_react_agent - - base_dir = Path(_ensure_result_path(state)) - base_dir.mkdir(parents=True, exist_ok=True) - - # 使用前端传递的模型,默认为 claude-sonnet-4-5 - model_name = getattr(getattr(state, "request", None), "gen_fig_model", "") or "claude-sonnet-4-5-20250929" - log.critical(f"[technical_route_bw_svg_generator] 使用模型: {model_name}") - - # 使用 create_react_agent 创建带验证器的 agent - agent = create_react_agent( - "technical_route_bw_svg_generator", - max_retries=3, - model_name=model_name, - temperature=0.0, - max_tokens=65536, - ) - - # 执行 agent(验证和重试由 agent 内部处理) - state = await agent.execute(state=state) - - # 获取生成的 SVG - svg_code = getattr(state, "figure_tec_svg_bw_content", None) - if not svg_code: - log.error("technical_route_bw_svg_generator_node: Agent 未返回 SVG 代码") - return state - - # SVG 后处理(注入命名空间、中文字体等) - svg_code = _post_process_svg(svg_code) - - # 保存 SVG 文件和渲染 PNG - timestamp = int(time.time()) - svg_output_path = str((base_dir / 
f"technical_route_bw_{timestamp}.svg").resolve()) - png_output_path = str((base_dir / f"technical_route_bw_{timestamp}.png").resolve()) - - try: - Path(svg_output_path).write_text(svg_code, encoding="utf-8") - png_path = local_tool_for_svg_render({ - "svg_code": svg_code, - "output_path": png_output_path, - }) - state.svg_bw_file_path = svg_output_path - state.svg_bw_img_path = png_path - state.svg_file_path = svg_output_path - state.svg_img_path = png_path - state.figure_tec_svg_content = svg_code - log.critical(f"[state.svg_bw_img_path]: {state.svg_bw_img_path}") - log.critical(f"[state.svg_bw_file_path]: {state.svg_bw_file_path}") - except Exception as e: - log.error(f"technical_route_bw_svg_generator_node: SVG 保存/渲染失败: {e}") - - return state - - async def technical_route_colorize_svg_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 彩色技术路线图生成(使用 ReAct 模式带验证器) - """ - from dataflow_agent.agentroles import create_react_agent - - base_dir = Path(_ensure_result_path(state)) - base_dir.mkdir(parents=True, exist_ok=True) - - palette_cfg = _get_palette_config(state) - if not palette_cfg: - return state - - if not hasattr(state, "temp_data"): - state.temp_data = {} - state.temp_data["palette_json"] = json.dumps(palette_cfg, ensure_ascii=False) - - # 使用前端传递的模型 - model_name = getattr(getattr(state, "request", None), "gen_fig_model", "") or "claude-sonnet-4-5-20250929" - log.critical(f"[technical_route_colorize_svg] 使用模型: {model_name}") - - # 使用 create_react_agent 创建带验证器的 agent - agent = create_react_agent( - "technical_route_colorize_svg", - max_retries=3, - model_name=model_name, - temperature=0.0, - max_tokens=65536, - ) - - # 执行 agent - state = await agent.execute(state=state) - - # 获取生成的彩色 SVG - svg_code = getattr(state, "figure_tec_svg_color_content", None) - if not svg_code: - log.error("technical_route_colorize_svg_node: Agent 未返回 SVG 代码") - return state - - # SVG 后处理(注入命名空间、中文字体等) - svg_code = _post_process_svg(svg_code) - - # 保存文件 - timestamp = 
int(time.time()) - svg_output_path = str((base_dir / f"technical_route_color_{timestamp}.svg").resolve()) - png_output_path = str((base_dir / f"technical_route_color_{timestamp}.png").resolve()) - - try: - Path(svg_output_path).write_text(svg_code, encoding="utf-8") - png_path = local_tool_for_svg_render({ - "svg_code": svg_code, - "output_path": png_output_path, - }) - state.svg_color_file_path = svg_output_path - state.svg_color_img_path = png_path - log.critical(f"[state.svg_color_img_path]: {state.svg_color_img_path}") - log.critical(f"[state.svg_color_file_path]: {state.svg_color_file_path}") - except Exception as e: - log.error(f"technical_route_colorize_svg_node: SVG 保存/渲染失败: {e}") - - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - - def set_entry_node(state: Paper2FigureState) -> str: - """ - 路由函数: 根据输入类型选择技术路线工作流的入口节点。 - - - input_type == "PDF" : 从 PDF 中抽取论文想法,先走 paper_idea_extractor - - input_type == "TEXT" : 检查是否有参考图 - - 有参考图: 先走 reference_image_analyzer - - 无参考图: 直接走 technical_route_bw_svg_generator - 其他值: - - 认为是不合法输入,直接结束工作流。 - """ - input_type = getattr(state.request, "input_type", "PDF") - has_ref = bool(state.temp_data.get("reference_image_path", "")) if hasattr(state, "temp_data") else False - - if input_type == "PDF": - log.critical("paper2technical: 进入 PDF 流程 (paper_idea_extractor)") - return "paper_idea_extractor" - elif input_type == "TEXT": - if has_ref: - log.critical("paper2technical: 进入 TEXT 流程 + 参考图 (reference_image_analyzer)") - return "reference_image_analyzer" - log.critical("paper2technical: 进入 TEXT 流程 (technical_route_bw_svg_generator)") - return "technical_route_bw_svg_generator" - else: - log.error(f"paper2technical: Invalid input type: {input_type}") - return "_end_" - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - """ - _start_ 节点:确保本次 workflow 有一个统一的 result_path 根目录。 - - 若用户已在 
state.result_path 传入自定义目录,则直接使用该目录; - - 若未传入,则初始化为 get_project_root()/outputs/paper2tec/。 - """ - _ensure_result_path(state) - return state - - nodes = { - "_start_": _init_result_path, - "paper_idea_extractor": paper_idea_extractor_node, - "reference_image_analyzer": reference_image_analyzer_node, - "technical_route_bw_svg_generator": technical_route_bw_svg_generator_node, - "technical_route_colorize_svg": technical_route_colorize_svg_node, - "_end_": lambda state: state, # 终止节点 - } - - # ------------------------------------------------------------------ - # EDGES (从节点 A 指向节点 B) - # ------------------------------------------------------------------ - edges = [ - # 参考图分析后,进入 SVG 生成 - ("reference_image_analyzer", "technical_route_bw_svg_generator"), - # 生成彩色后,直接结束(不再生成 PPT) - ("technical_route_colorize_svg", "_end_"), - ] - - def _route_after_idea_extractor(state: Paper2FigureState) -> str: - """ - 路由函数: paper_idea_extractor 之后,检查是否有参考图 - - 有参考图: 先走 reference_image_analyzer - - 无参考图: 直接走 technical_route_bw_svg_generator - """ - has_ref = bool(state.temp_data.get("reference_image_path", "")) if hasattr(state, "temp_data") else False - if has_ref: - log.critical("[_route_after_idea_extractor] -> reference_image_analyzer") - return "reference_image_analyzer" - log.critical("[_route_after_idea_extractor] -> technical_route_bw_svg_generator (no ref)") - return "technical_route_bw_svg_generator" - - def _route_after_bw(state: Paper2FigureState) -> str: - palette = getattr(getattr(state, "request", None), "tech_route_palette", "") or "" - log.critical(f"[_route_after_bw] tech_route_palette: '{palette}'") - if palette: - log.critical(f"[_route_after_bw] -> technical_route_colorize_svg") - return "technical_route_colorize_svg" - # 无配色时直接结束,不再生成 PPT - log.critical(f"[_route_after_bw] -> _end_ (no palette, skip PPT)") - return "_end_" - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("_start_", set_entry_node) - 
builder.add_conditional_edge("paper_idea_extractor", _route_after_idea_extractor) - builder.add_conditional_edge("technical_route_bw_svg_generator", _route_after_bw) - return builder diff --git a/dataflow_agent/workflow/wf_paper2video.py b/dataflow_agent/workflow/wf_paper2video.py deleted file mode 100644 index 2ddbd69..0000000 --- a/dataflow_agent/workflow/wf_paper2video.py +++ /dev/null @@ -1,300 +0,0 @@ -""" -paper2video workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-11-26 11:08:03 - -1. 在 **TOOLS** 区域定义需要暴露给 Prompt 的前置工具 -2. 在 **NODES** 区域实现异步节点函数 (await-able) -3. 在 **EDGES** 区域声明有向边 -4. 最后返回 builder.compile() 或 GenericGraphBuilder -""" - -from __future__ import annotations -import json -from dataclasses import Field -from pydantic import BaseModel -from dataflow_agent.state import Paper2VideoRequest, Paper2VideoState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register - -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from langchain.tools import tool -from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode, tools_condition -from dataflow_agent.toolkits.p2vtool.p2v_tool import compile_tex, beamer_code_validator, get_image_paths, parse_script, transcribe_with_whisperx, inference_f5 - -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger -from pathlib import Path -from pdf2image import convert_from_path -from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_mineru_pdf_extract - -log = get_logger(__name__) - -# @register("paper2video") -def create_paper2video_graph() -> GenericGraphBuilder: - """ - Workflow factory: dfa run --wf paper2video - """ - builder = GenericGraphBuilder(state_model=Paper2VideoState, - entry_point="p2v_extract_pdf") # 自行修改入口 - - # 
---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - - @builder.pre_tool("pdf_markdown", "p2v_extract_pdf") - def get_markdown(state: Paper2VideoState): - import subprocess - paper_pdf_path = Path(state.request.get("paper_pdf_path", "")) - # paper_pdf_path = Path("/mnt/DataFlow/lz/proj/agentgroup/ligang/DataFlow-Agent/data/2510.05096v2.pdf") - if not paper_pdf_path.exists(): - log.error(f"PDF 文件不存在: {paper_pdf_path}") - return "" - paper_pdf_dir = paper_pdf_path.with_suffix('').parent - if not paper_pdf_path.with_suffix('').exists(): - #fixme: 这里需要修改为部署机器上的mineru - run_mineru_pdf_extract(str(paper_pdf_path), str(paper_pdf_dir), "modelscope") - - paper_base_path = paper_pdf_path.with_suffix('').expanduser().resolve() - paper_output_dir = paper_base_path - markdown_path = paper_output_dir / "auto" / f"{paper_base_path.name}.md" - if not markdown_path.exists(): - log.error(f"Markdown 文件不存在: {str(markdown_path)}") - return "" - try: - markdown_content = markdown_path.read_text(encoding='utf-8') - return markdown_content - except Exception as e: - log.error(f'读取 markdown 文件内容失败:{markdown_path}. 
错误:{e}') - return "" - - @builder.pre_tool("pdf_images_working_dir", "p2v_extract_pdf") - def get_images_relative_path(state: Paper2VideoState): - paper_pdf_path = Path(state.request.get("paper_pdf_path", "")) - if not paper_pdf_path.exists(): - log.error(f"PDF 文件不存在: {paper_pdf_path}") - return "" - paper_base_path = paper_pdf_path.with_suffix('').expanduser().resolve() - paper_output_dir = paper_base_path - images_dir = paper_output_dir/"auto" - if not images_dir.exists(): - log.error(f"没有生成对应的图片,MinerU 识别图像失败:{images_dir}") - return "" - return str(images_dir) - - @builder.pre_tool("output_language", "p2v_extract_pdf") - def get_language(state: Paper2VideoState): - language_map = { - 'en': "English", - 'zh': "Chinese", - } - language = state.request.language - return language_map.get(language, "English") - - @builder.pre_tool("is_beamer_wrong", "p2v_beamer_code_debug") - def get_is_code_wrong(state: Paper2VideoState): - return state.is_beamer_wrong - - @builder.pre_tool("is_beamer_warning", "p2v_beamer_code_debug") - def get_is_code_warning(state: Paper2VideoState): - return state.is_beamer_warning - - @builder.pre_tool("code_debug_result", "p2v_beamer_code_debug") - def get_compile_result(state: Paper2VideoState): - return state.code_debug_result - - @builder.pre_tool("beamer_code", "p2v_beamer_code_debug") - def get_beamer_code(state: Paper2VideoState): - beamer_code_path = state.beamer_code_path - beamer_code = Path(beamer_code_path).read_text(encoding='utf-8') - return beamer_code - - @builder.pre_tool("set_subtitle_and_cursor_path", "p2v_subtitle_and_cursor") - def set_subtitle_and_cursor_path(state: Paper2VideoState): - # 因为是循环调用VLM,所以这里就只调用一次 - if state.subtitle_and_cursor_path != "" and state.slide_img_dir != "": - return None - '''处理好slide_img,并且处理好路径,同时将最后输出文档的地址写好''' - paper_pdf_path = Path(state.request.get("paper_pdf_path", "")) - if not paper_pdf_path.exists(): - log.error(f"PDF 文件不存在: {paper_pdf_path}") - return "" - paper_base_path = 
paper_pdf_path.with_suffix('').expanduser().resolve() - paper_output_dir = paper_base_path - subtitle_and_cursor_path = paper_output_dir/"subtitle_w_cursor.txt" - state.subtitle_and_cursor_path = str(subtitle_and_cursor_path) - - slide_img_dir = paper_output_dir/"slide_imgs" - slide_img_dir.mkdir(parents=True, exist_ok=True) - slide_imgs = convert_from_path(state.ppt_path) - for i, img in enumerate(slide_imgs): - img_path = slide_img_dir / f"slide_{i+1:03d}.png" - img.save(img_path, 'PNG') - state.slide_img_dir = str(slide_img_dir) - return None - - # 后置工具就是让agent选择的工具,可以定制多个; - # class ModuleListInput(BaseModel): - # #这里要写好工具的描述,agent会根据实际上下文输入参数: - # module_list: list = Field( - # description="List of dotted-path python modules or file paths" - # ) - # @builder.post_tool("step2") - # @tool(args_schema=ModuleListInput) - # def _post_tool1(module_list): - # return func(module_list) - - # ---------------------------------------------------------------------- - - # ============================================================== - # NODES - # ============================================================== - async def extract_pdf_node(state: Paper2VideoState) -> Paper2VideoState: - from dataflow_agent.agentroles import create_vlm_agent - log.info("开始执行extract_pdf_node节点") - agent = create_vlm_agent( - name="p2v_extract_pdf", - vlm_mode="understanding", # 视觉模式: 'understanding', 'generation', 'edit' - image_detail="high", # 图像细节: 'low', 'high', 'auto' - model_name="gpt-4o-2024-11-20", # 视觉模型 - temperature=0.1, - max_image_size=(2048, 2048), # 最大图像尺寸 - - # additional_params={}, # 额外VLM参数,可以存放图片用法为:"input_image": image_path - ) - - state = await agent.execute(state=state) - - # 可选:处理执行结果 - # agent_result = state.agent_results.get(agent.role_name, {}) - # log.info(f"Agent {agent.role_name} 执行结果: {agent_result}") - - return state - - def compile_beamer_node(state: Paper2VideoState) -> Paper2VideoState: - log.info(f"开始执行compile_beamer_node") - beamer_code_path = 
state.beamer_code_path - state.is_beamer_wrong, state.is_beamer_warning, state.code_debug_result = compile_tex(beamer_code_path) - if not state.is_beamer_warning: - log.info(f"Beamer 代码编译成功,无需调试") - state.ppt_path = state.beamer_code_path.replace(".tex", ".pdf") - return state - - async def beamer_code_debug_node(state: Paper2VideoState) -> Paper2VideoState: - from dataflow_agent.agentroles import create_react_agent - log.info(f"开始执行 p2v_beamer_code_debug node节点") - agent = create_react_agent( - name="p2v_beamer_code_debug", - model_name="gpt-4o-2024-11-20", - max_retries=10, - validators=[beamer_code_validator], - ) - state = await agent.execute(state) - return state - - async def subtitle_and_cursor(state: Paper2VideoState) -> Paper2VideoState: - log.info(f"开始执行 p2v_subtitle_and_cursor node节点") - from dataflow_agent.agentroles import create_vlm_agent - - slide_img_dir = state.slide_img_dir - slide_image_path_list = get_image_paths(slide_img_dir) - image_paths = '\n'.join(slide_image_path_list) - log.info(f"获得了slide_image from {slide_img_dir}, the total images are {len(slide_image_path_list)}, the images path are {image_paths}") - for img_path in slide_image_path_list: - agent = create_vlm_agent( - name="p2v_subtitle_and_cursor", - vlm_mode="understanding", # 视觉模式: 'understanding', 'generation', 'edit' - image_detail="high", # 图像细节: 'low', 'high', 'auto' - model_name="gpt-4o-2024-11-20", # 视觉模型 - temperature=0.1, - max_image_size=(2048, 2048), # 最大图像尺寸 - - additional_params={ - "input_image": img_path, - }, # 额外VLM参数,可以存放图片用法为:"input_image": image_path - ) - state = await agent.execute(state=state) - subtitle_and_cursor_info = "\n###\n".join(state.subtitle_and_cursor) - log.info(f"获取了完整的 Subtitle and Cursor 信息:\n {subtitle_and_cursor_info}") - subtitle_and_cursor_path = state.subtitle_and_cursor_path - log.info(f"内容将写入到文件地址 {subtitle_and_cursor_path}中......") - Path(subtitle_and_cursor_path).write_text(subtitle_and_cursor_info, encoding='utf-8') - return state - - 
def generate_speech(state: Paper2VideoState): - # 先完成pre-tool的工作 - import os - log.info(f"开始执行 p2v_generate_speech node节点") - subtitle_and_cursor_path = state.subtitle_and_cursor_path - speech_save_dir = state.speech_save_dir - os.makedirs(speech_save_dir, exist_ok=True) - ref_audio_path = state.request.ref_audio_path - - # 1、拿到subtitle的文件,并且读出其中的内容,并解析 - raw_subtitle_and_cursor_content = Path(subtitle_and_cursor_path).read_text(encoding='utf-8') - log.info(f"获取到字幕内容:\n{raw_subtitle_and_cursor_content}") - parsed_subtitle_w_cursor = parse_script(raw_subtitle_and_cursor_content) - - # 2、不同的slide分别进行处理 - for slide_idx in range(len(parsed_subtitle_w_cursor)): - speech_with_cursor = parsed_subtitle_w_cursor[slide_idx] - subtitle = "" - for _, (prompt, cursor_prompt) in enumerate(speech_with_cursor): - if len(subtitle) == 0: subtitle = prompt - else: subtitle = subtitle + "\n\n\n" + prompt - speech_result_path = os.path.join(speech_save_dir, "{}.wav".format(str(slide_idx))) - - # 3、将每个slide的字幕内容转换为音频,并保存到指定的目录 - if ref_text is None: ref_text = transcribe_with_whisperx(ref_audio_path) - inference_f5(subtitle, speech_result_path, ref_audio_path, ref_text) - - async def compile_beamer_condition(state: Paper2VideoState): - # todo: 暂时先这样判断 - if state.is_beamer_warning: - return "p2v_beamer_code_debug" - else: - return "__end__" - - - async def pdf2ppt_node(state: Paper2VideoState) -> Paper2VideoState: - - log.info(f"开始执行 pdf2ppt node节点") - from dataflow_agent.agentroles import create_simple_agent - # agent = create_simple_agent( - # name="" - # ) - - - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - nodes = { - "p2v_extract_pdf": extract_pdf_node, - "compile_beamer": compile_beamer_node, - "p2v_beamer_code_debug": beamer_code_debug_node, - "p2v_subtitle_and_cursor": subtitle_and_cursor, - "p2v_generate_speech": generate_speech, - "pdf2ppt": 
pdf2ppt_node, - '_end_': lambda state: state, # 终止节点 - } - - # ------------------------------------------------------------------ - # EDGES (从节点 A 指向节点 B) - # ------------------------------------------------------------------ - edges = [ - ("p2v_extract_pdf", "compile_beamer"), - ("p2v_beamer_code_debug", "__end__") - ] - - builder.add_nodes(nodes).add_edges(edges).add_conditional_edge("compile_beamer", compile_beamer_condition) - return builder - -if __name__ == "__main__": - import asyncio - graph_builder = create_paper2video_graph().build() - - p2v_state = Paper2VideoState(request=Paper2VideoRequest(chat_api_url="http://123.129.219.111:3000/v1")) - out = asyncio.run(graph_builder.ainvoke(p2v_state)) \ No newline at end of file diff --git a/dataflow_agent/workflow/wf_pdf2ppt_optimized.py b/dataflow_agent/workflow/wf_pdf2ppt_optimized.py deleted file mode 100644 index 2aaf1bc..0000000 --- a/dataflow_agent/workflow/wf_pdf2ppt_optimized.py +++ /dev/null @@ -1,1082 +0,0 @@ -""" -pdf2ppt_optimized workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Optimized version of pdf2ppt_parallel: -1. **Text Extraction**: Uses MinerU blocks for both layout and text content, relying on its layout analysis - to ensure correct column separation. -2. **Dynamic Slide Sizing**: Sets PPT slide dimensions to match input image resolution exactly. -3. **Smart Text Styling**: - - Samples text color from the original image. - - Auto-centers titles. - - Calculates optimal font size. 
-""" - -from __future__ import annotations -import os -import asyncio -from pathlib import Path -from typing import List, Dict, Any, Optional -from collections import Counter -import copy -import time -import math - -import cv2 -import numpy as np -import fitz # PyMuPDF -import yaml -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root, pixels_to_inches, calculate_font_size - -# Tools -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_bg_remove, free_bg_rm_model -from dataflow_agent.toolkits.multimodaltool.mineru_tool import recursive_mineru_layout -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.toolkits.multimodaltool import ppt_tool - -from pptx import Presentation -from pptx.util import Inches, Pt -from pptx.dml.color import RGBColor -from pptx.enum.text import PP_ALIGN, MSO_AUTO_SIZE - -from dataflow_agent.toolkits.multimodaltool.ppt_text_fit import DEFAULT_FITTER, TextFitStyle - -log = get_logger(__name__) - -# Load configuration from yaml -def load_server_config(): - root = get_project_root() - config_path = root / "conf" / "model_servers.yaml" - if not config_path.exists(): - log.warning(f"Config file not found at {config_path}, using defaults.") - return {} - try: - with open(config_path, "r") as f: - return yaml.safe_load(f) or {} - except Exception as e: - log.error(f"Failed to load config: {e}") - return {} - -SERVER_CONFIG = load_server_config() - -def _resize_image_for_ppt(img_path: str, max_w: int = 1920, max_h: int = 1080) -> str: - """ - Force resize image to fit within max_w x max_h 
(contain), keeping aspect ratio. - Overwrites the original image or saves a new one. - This prevents huge images from crashing pptx (max 56 inches) and stabilizes MinerU/SAM results. - """ - if not os.path.exists(img_path): - return img_path - - try: - with Image.open(img_path) as img: - w, h = img.size - # If image is already smaller than target box, skip (or optionally upscale?) - # Here we only downscale to avoid quality loss on small icons, - # unless it's way too small? For now, mainly prevent huge images. - if w <= max_w and h <= max_h: - return img_path - - # Calculate scale to contain - scale = min(max_w / w, max_h / h) - new_w = int(w * scale) - new_h = int(h * scale) - - log.info(f"[pdf2ppt_opt] Resizing input image from {w}x{h} to {new_w}x{new_h} (fit {max_w}x{max_h})") - - resized = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - - # Save to a new path to keep original safe (optional, but good practice) - # or overwrite if we want consistency for downstream tools. - # Let's overwrite/update filename to indicate resized. 
- p = Path(img_path) - new_path = p.parent / f"{p.stem}_resized{p.suffix}" - resized.save(new_path) - return str(new_path) - - except Exception as e: - log.error(f"[pdf2ppt_opt] Resize failed: {e}") - return img_path - -# Helper to construct URLs -def get_sam_urls(): - if os.environ.get("SAM_SERVER_URLS"): - return os.environ.get("SAM_SERVER_URLS").split(",") - sam_cfg = SERVER_CONFIG.get("sam", {}) - instances = sam_cfg.get("instances", []) - if instances: - urls = [] - for inst in instances: - for port in inst.get("ports", []): - urls.append(f"http://127.0.0.1:{port}") - if urls: - return urls - return ["http://localhost:8021", "http://localhost:8022","http://localhost:8023"] - -SAM_SERVER_URLS = get_sam_urls() - -def _ensure_result_path(state: Paper2FigureState) -> str: - raw = getattr(state, "result_path", None) - if raw: - return raw - root = get_project_root() - ts = int(__import__("time").time()) - base_dir = (root / "outputs" / "pdf2ppt_optimized" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - -def _run_sam_on_pages(image_paths: List[str], base_dir: str) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - sam_ckpt = f"{get_project_root()}/sam_b.pt" - - for page_idx, img_path in enumerate(image_paths): - img_path_obj = Path(img_path) - if not img_path_obj.exists(): - log.warning(f"[pdf2ppt_opt] image not found for SAM: {img_path}") - results.append({"page_idx": page_idx, "layout_items": []}) - continue - - out_dir = Path(base_dir) / "layout_items" / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - try: - layout_items = segment_layout_boxes_server( - image_path=str(img_path_obj), - output_dir=str(out_dir), - server_urls=SAM_SERVER_URLS, - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=25, - nms_by="mask", - ) - except Exception as e: - log.error(f"[pdf2ppt_opt] Remote SAM failed: {e}. 
Fallback to local.") - layout_items = segment_layout_boxes( - image_path=str(img_path_obj), - output_dir=str(out_dir), - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=25, - nms_by="mask", - ) - - try: - pil_img = Image.open(str(img_path_obj)) - w, h = pil_img.size - except Exception as e: - log.error(f"[pdf2ppt_opt][page#{page_idx+1}] open image failed: {e}") - w, h = 1024, 768 - - for it in layout_items: - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * w)) - y1 = int(round(y1n * h)) - x2 = int(round(x2n * w)) - y2 = int(round(y2n * h)) - if x2 > x1 and y2 > y1: - it["bbox_px"] = [x1, y1, x2, y2] - - results.append({"page_idx": page_idx, "layout_items": layout_items}) - - try: - free_sam_model(checkpoint=sam_ckpt) - except Exception as e: - log.error(f"[pdf2ppt_opt] free_sam_model failed: {e}") - - return results - -# ============================================================================== -# Helper Functions for PPT Generation -# ============================================================================== - -def _get_dominant_color(img_path: str, bbox: List[int]) -> RGBColor: - """ - Smartly extract text color from the image region. - Logic: - 1. Identify background color (most frequent). - 2. Find text color (frequent color with high contrast to background). - 3. Fallback to Black/White based on background brightness. 
- """ - try: - with Image.open(img_path) as img: - x1, y1, x2, y2 = bbox - # Clamp - x1 = max(0, x1) - y1 = max(0, y1) - x2 = min(img.width, x2) - y2 = min(img.height, y2) - - if x2 <= x1 or y2 <= y1: - return RGBColor(0, 0, 0) - - crop = img.crop((x1, y1, x2, y2)) - - # Reduce resolution to speed up - crop.thumbnail((50, 50)) - crop = crop.convert("RGB") - - # Get colors (count, (r,g,b)) - colors = crop.getcolors(maxcolors=2500) - if not colors: - return RGBColor(0, 0, 0) - - # Sort by frequency (descending) - # Assumption: The most frequent color is the background - colors.sort(key=lambda x: x[0], reverse=True) - - bg_rgb = colors[0][1] - - def get_brightness(c): - return 0.299 * c[0] + 0.587 * c[1] + 0.114 * c[2] - - def get_dist(c1, c2): - return abs(c1[0]-c2[0]) + abs(c1[1]-c2[1]) + abs(c1[2]-c2[2]) - - bg_brightness = get_brightness(bg_rgb) - - # Find the first color that has enough contrast with background - # We skip the first one (background itself) - best_rgb = None - - for i in range(1, len(colors)): - count, rgb = colors[i] - # Ignore very rare artifacts - if count < colors[0][0] * 0.05: - continue - - # Check contrast - # 1. Brightness diff > 50 - # 2. Or total Euclidean-ish dist > 100 - bri = get_brightness(rgb) - dist = get_dist(rgb, bg_rgb) - - if abs(bri - bg_brightness) > 60 or dist > 150: - best_rgb = rgb - break - - # Fallback if no contrasting color found (e.g. solid block) - if best_rgb is None: - # If background is dark, use white; else black - if bg_brightness < 128: - best_rgb = (255, 255, 255) - else: - best_rgb = (0, 0, 0) - - return RGBColor(best_rgb[0], best_rgb[1], best_rgb[2]) - - except Exception as e: - log.warning(f"[pdf2ppt_opt] Color sampling failed: {e}") - return RGBColor(0, 0, 0) - -def _estimate_font_size_from_slide_geometry_pt( - text: str, - bbox_px: List[int], - text_level: int, - *, - dpi: int = 96, - line_spacing: float = 1.0, -) -> int: - """ - Stable font-size estimation based on *PPT physical size*, not raw pixels. 
- - We map bbox height (inches) -> points and then shrink by estimated wrapping. - - This fixes the common failure modes: - - FIGURE mode: image is scaled into a fixed 16:9 canvas, bbox_px becomes smaller, - but the PPT physical size stays consistent. Using inches keeps font size stable. - - Huge input images: avoids tiny fonts after downscaling. - - Small images: avoids huge fonts that overflow. - """ - text = (text or "").strip() - if not text: - return 12 - - x1, y1, x2, y2 = bbox_px - box_w_px = max(1, x2 - x1) - box_h_px = max(1, y2 - y1) - - box_w_in = pixels_to_inches(box_w_px, dpi=dpi) - box_h_in = pixels_to_inches(box_h_px, dpi=dpi) - box_w_pt = box_w_in * 72.0 - box_h_pt = box_h_in * 72.0 - - # Base ratio by level (tunable) - if text_level == 1: # title/header - ratio = 0.75 - min_pt, max_pt = 14, 56 - elif text_level == 2: # subtitle - ratio = 0.60 - min_pt, max_pt = 12, 40 - else: # body - ratio = 0.52 - # 8-10pt is often unreadable for screenshots; raise baseline. - min_pt, max_pt = 11, 30 - - pt = max(min_pt, min(max_pt, box_h_pt * ratio)) - - # Shrink by estimated wrapping - # Rough width model: latin ~0.5-0.6em, CJK closer to 1.0em; use 0.75 to be conservative. - char_count = len(text) - if char_count > 0: - chars_per_line = max(1.0, box_w_pt / (pt * 0.75)) - lines_needed = max(1, int(math.ceil(char_count / chars_per_line))) - - max_lines = max(1.0, box_h_pt / (pt * max(0.8, float(line_spacing)))) - if lines_needed > max_lines: - pt = pt * (max_lines / lines_needed) - - pt = max(min_pt, min(max_pt, pt)) - return int(max(6, round(pt))) - - -def _normalize_text_bbox_px( - bbox_px: List[int], - *, - canvas_h_px: int, - text: str, - text_level: int, -) -> List[int]: - """ - MinerU sometimes returns extremely thin text bboxes (few pixels tall), - which makes any font-size estimation unusable (text becomes invisible). - - This function inflates bbox height to a minimum readable band while keeping - the horizontal placement. 
It is a pragmatic fix for OCR/layout bbox noise. - - UPDATE: Now intelligently calculates required height for long text based on - estimated line wrapping, preventing font from being crushed to 8pt. - - Args: - bbox_px: [x1, y1, x2, y2] in slide-canvas px - canvas_h_px: slide canvas height in px (target_h_px for FIGURE, slide_h_px for PDF) - text_level: 1(title) / 2(subtitle) / 3(body) - """ - x1, y1, x2, y2 = bbox_px - h = max(1, y2 - y1) - w = max(1, x2 - x1) - - # Base minimum bbox height in px on slide canvas - if text_level == 1: - min_h = 30 - elif text_level == 2: - min_h = 22 - else: - min_h = 16 if len((text or "").strip()) >= 20 else 12 - - # For long text, dynamically calculate required height based on wrapping - text_len = len((text or "").strip()) - if text_len > 50 and text_level == 3: # Long body text - # Target font size for estimation: 10pt (~13.3px at 96dpi) - target_pt = 10 - target_px = target_pt * 96.0 / 72.0 - - # Estimate chars per line (conservative: avg char width = 0.75 * font_size) - chars_per_line = max(1.0, w / (target_px * 0.75)) - - # Estimate lines needed - lines_needed = max(1, int(math.ceil(text_len / chars_per_line))) - - # Calculate required height (lines × line_height, line_height = font_size × 1.2) - line_height_px = target_px * 1.2 - required_h = int(lines_needed * line_height_px) - - # Use the larger of base min_h and calculated required_h - min_h = max(min_h, required_h) - - # Cap at reasonable max to avoid excessive boxes - min_h = min(min_h, int(canvas_h_px * 0.4)) - - # Log for diagnosis - log.info(f"[pdf2ppt_opt] Text BBox Normalization: text='{text[:50]}...', level={text_level}, original_bbox={bbox_px}, original_h={h}, min_h={min_h}") - - if h >= min_h: - return bbox_px - - cy = (y1 + y2) / 2.0 - new_y1 = int(round(cy - min_h / 2.0)) - new_y2 = int(round(cy + min_h / 2.0)) - - # Clamp to canvas - new_y1 = max(0, new_y1) - new_y2 = min(int(canvas_h_px), new_y2) - if new_y2 <= new_y1: - new_y2 = min(int(canvas_h_px), 
new_y1 + 1) - - # Log normalized bbox - log.info(f"[pdf2ppt_opt] Normalized BBox: new_bbox=[{x1}, {new_y1}, {x2}, {new_y2}], new_h={new_y2 - new_y1}") - - return [x1, new_y1, x2, new_y2] - - -def _calculate_font_size(text: str, bbox: List[int], text_level: int = None) -> int: - """ - Stable heuristic estimation (default). - - Previously we used `utils.calculate_font_size()` which directly maps bbox pixel height to pt. - That is unstable across different slide scaling strategies (PDF vs FIGURE). - Now we estimate using slide physical geometry (inches -> pt) so it is consistent. - """ - lvl = text_level or 3 - return int(_estimate_font_size_from_slide_geometry_pt(text=text, bbox_px=bbox, text_level=lvl)) - -def _add_smart_textbox( - slide, - text: str, - bbox: List[int], - text_level: int, - color: RGBColor = None, - *, - slide_w_px: int = None, - slide_h_px: int = None, - enable_render_fit: bool = False, - fitter_dpi: int = 96, -): - """ - Add a text box with stable wrapping + optional render-fit sizing (LibreOffice). - """ - left_in = pixels_to_inches(bbox[0]) - top_in = pixels_to_inches(bbox[1]) - width_in = pixels_to_inches(bbox[2] - bbox[0]) - height_in = pixels_to_inches(bbox[3] - bbox[1]) - - # Baseline heuristic (stable across PDF/FIGURE) - font_size = _calculate_font_size(text, bbox, text_level) - - # Render-fit sizing (LibreOffice) - best for titles/short text. - # IMPORTANT: upper_pt must be in PT, not PX. 
- if enable_render_fit and slide_w_px and slide_h_px: - try: - style = TextFitStyle( - font_name="Arial", - bold=bool(text_level == 1), - line_spacing=1.0, - margin_px=2, # keep a tiny margin to avoid borderline overflow - ) - - box_h_px = max(1, bbox[3] - bbox[1]) - box_h_in = pixels_to_inches(box_h_px, dpi=fitter_dpi) - upper_pt = max(8, int(box_h_in * 72.0 * 0.95)) - - font_size = DEFAULT_FITTER.fit_font_size_pt( - text=text, - bbox_px=(bbox[0], bbox[1], bbox[2], bbox[3]), - slide_w_px=int(slide_w_px), - slide_h_px=int(slide_h_px), - style=style, - lower_pt=8 if text_level != 1 else 14, # Allow body text to go down to 8pt - upper_pt=upper_pt, - tolerance_px=2, - max_iter=15, - ) - log.info(f"[pdf2ppt_opt] Render Fit: text='{text[:50]}...', bbox={bbox}, calculated_font_size={font_size}") - except Exception as e: - log.warning(f"[pdf2ppt_opt] render-fit failed, fallback to heuristic: {e}") - - # Create textbox - textbox = slide.shapes.add_textbox( - Inches(left_in), Inches(top_in), Inches(width_in), Inches(height_in) - ) - - tf = textbox.text_frame - tf.word_wrap = True - - # Titles: keep stable layout. Body: allow auto-shrink to prevent overflow. - # UPDATE: Since we use strict render-fitting now, we MUST disable auto-fit - # to prevent PPT from overriding our calculated font size (e.g. forcing 8pt). 
- tf.auto_size = MSO_AUTO_SIZE.NONE - - # Keep margins in sync with fitter margin_px=2 (dpi=96 => ~0.0208in) - m_in = pixels_to_inches(2, dpi=fitter_dpi) - tf.margin_left = Inches(m_in) - tf.margin_right = Inches(m_in) - tf.margin_top = Inches(m_in) - tf.margin_bottom = Inches(m_in) - - p = tf.paragraphs[0] - p.text = text - p.font.size = Pt(int(font_size)) - try: - p.line_spacing = 1.0 - except Exception: - pass - - # Font style - # Use a standard font or keep default - # p.font.name = "Arial" - - if color: - p.font.color.rgb = color - else: - p.font.color.rgb = RGBColor(0, 0, 0) - - if text_level == 1: - p.font.bold = True - p.alignment = PP_ALIGN.CENTER - elif text_level == 2: - p.font.bold = True - -@register("pdf2ppt_optimized") -def create_pdf2ppt_optimized_graph() -> GenericGraphBuilder: - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - async def pdf_to_images_node(state: Paper2FigureState) -> Paper2FigureState: - if state.request.input_type == "FIGURE": - img_path = state.request.input_content - # 强制开启 AI 编辑,以便在转 PPT 过程中去除背景文字 - state.use_ai_edit = True - if img_path and os.path.exists(img_path): - state.slide_images = [img_path] - return state - - pdf_path = getattr(state, "pdf_file", None) - if not pdf_path: - log.error("[pdf2ppt_opt] state.pdf_file is empty") - return state - - base_dir = Path(_ensure_result_path(state)) - img_dir = base_dir / "slides_png" - image_paths = ppt_tool.pdf_to_images(pdf_path, str(img_dir)) - state.slide_images = image_paths - return state - - # --- Reused Nodes Logic from parallel wf --- - - async def slides_mineru_node(state: Paper2FigureState) -> Paper2FigureState: - image_paths: List[str] = getattr(state, "slide_images", []) or [] - base_dir = Path(_ensure_result_path(state)) - mineru_dir = base_dir / "mineru_pages" - mineru_dir.mkdir(parents=True, exist_ok=True) - 
port = getattr(getattr(state, "request", None), "mineru_port", 8010) - - mineru_pages: List[Dict[str, Any]] = [] - - for page_idx, img_path in enumerate(image_paths): - try: - out_dir = mineru_dir / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - mineru_items = await recursive_mineru_layout( - image_path=str(img_path), - port=port, - max_depth=3, - output_dir=str(out_dir), - ) - mineru_pages.append({ - "page_idx": page_idx, - "blocks": mineru_items, - "path": img_path, - "mineru_output_dir": str(out_dir) - }) - except Exception as e: - log.error(f"[pdf2ppt_opt][MinerU] page#{page_idx+1} failed: {e}") - mineru_pages.append({ - "page_idx": page_idx, - "blocks": [], - "path": img_path, - }) - - state.mineru_pages = mineru_pages - return state - - async def slides_sam_node(state: Paper2FigureState) -> Paper2FigureState: - image_paths: List[str] = getattr(state, "slide_images", []) or [] - base_dir = _ensure_result_path(state) - sam_pages = await asyncio.to_thread(_run_sam_on_pages, image_paths, base_dir) - state.sam_pages = sam_pages - return state - - async def slides_layout_bg_remove_node(state: Paper2FigureState, sam_pages: List[Dict[str, Any]] = None) -> Paper2FigureState: - if sam_pages is None: - sam_pages = getattr(state, "sam_pages", []) or [] - - base_dir = Path(_ensure_result_path(state)) - icons_dir = base_dir / "sam_icons" - icons_dir.mkdir(parents=True, exist_ok=True) - model_path = getattr(getattr(state, "request", None), "bg_rm_model", None) - - def _sync_bg_remove(): - processed = 0 - for p in sam_pages: - page_idx = p.get("page_idx", 0) - for it in p.get("layout_items", []): - png_path = it.get("png_path") - if not png_path or not os.path.exists(png_path): continue - - try: - original_stem = Path(png_path).stem - output_filename = f"page_{page_idx+1:03d}_{original_stem}_bg_removed.png" - output_path = icons_dir / output_filename - - req = {"image_path": png_path, "output_dir": str(icons_dir)} - if model_path: req["model_path"] 
= model_path - - fg_path = local_tool_for_bg_remove(req) - - if fg_path and os.path.exists(fg_path): - fg_path_obj = Path(fg_path) - if fg_path_obj.name != output_filename: - new_fg_path = fg_path_obj.parent / output_filename - fg_path_obj.rename(new_fg_path) - fg_path = str(new_fg_path) - it["fg_png_path"] = fg_path - else: - it["fg_png_path"] = png_path - processed += 1 - except Exception as e: - log.error(f"[pdf2ppt_opt][bg_rm] failed for {png_path}: {e}") - it["fg_png_path"] = png_path - - try: - if model_path: free_bg_rm_model(model_path=model_path) - except Exception: pass - return processed - - await asyncio.to_thread(_sync_bg_remove) - state.sam_pages = sam_pages - return state - - async def parallel_processing_node(state: Paper2FigureState) -> Paper2FigureState: - import copy - import time - - start_time = time.time() - - async def mineru_branch(): - branch_state = copy.copy(state) - result = await slides_mineru_node(branch_state) - return ("mineru", result) - - async def sam_branch(): - branch_state = copy.copy(state) - branch_state = await slides_sam_node(branch_state) - sam_pages = getattr(branch_state, "sam_pages", []) - branch_state = await slides_layout_bg_remove_node(branch_state, sam_pages=sam_pages) - return ("sam", branch_state) - - results = await asyncio.gather(mineru_branch(), sam_branch(), return_exceptions=True) - - for r in results: - if isinstance(r, Exception): - log.error(f"[pdf2ppt_opt] Branch failed: {r}") - continue - - branch_name, branch_state = r - if branch_name == "mineru": - state.mineru_pages = getattr(branch_state, "mineru_pages", None) - elif branch_name == "sam": - state.sam_pages = getattr(branch_state, "sam_pages", None) - - log.info(f"[pdf2ppt_opt] Parallel processing done in {time.time() - start_time:.2f}s") - return state - - async def slides_ppt_generation_node(state: Paper2FigureState) -> Paper2FigureState: - """ - Final PPT Generation using Hybrid Logic + Dynamic Sizing + Smart Style - """ - # Helper: API Retry - 
async def _call_image_api_with_retry(coro_factory, retries: int = 3, delay: float = 1.0) -> bool: - last_err = None - for attempt in range(1, retries + 1): - try: - await coro_factory() - return True - except Exception as e: - last_err = e - log.error(f"[pdf2ppt_opt] image api failed attempt {attempt}/{retries}: {e}") - await asyncio.sleep(delay) - log.error(f"[pdf2ppt_opt] image api failed after {retries} attempts: {last_err}") - return False - - sam_pages = getattr(state, "sam_pages", []) or [] - mineru_pages = getattr(state, "mineru_pages", []) or [] - - if not mineru_pages: - log.error("[pdf2ppt_opt] No MinerU pages found! Aborting.") - return state - - # Create dict indices - sam_dict = {p.get("page_idx"): p for p in sam_pages} - - # Determine PPT size from first image - first_img_path = getattr(state, "slide_images", [])[0] if getattr(state, "slide_images", []) else None - - prs = Presentation() - - # ================= 固定 PPT 物理尺寸 & 统一画布分辨率 ================= - # 避免超大输入图导致 PPT 宽度 > 56 英寸 (pptx 限制),同时让字号/布局有稳定参照系。 - # FIGURE / PDF 都统一到固定画布,再映射 bbox。 - is_figure_mode = bool(getattr(getattr(state, "request", None), "input_type", None) == "FIGURE") - - # 物理尺寸:16:9 宽屏,限制在常见 PPT 范围内(不超过 56 in) - target_w_in, target_h_in = 13.333, 7.5 # 16:9 widescreen - - # 画布像素分辨率:统一使用 1280x720 (13.333*96, 7.5*96),方便后续 bbox 映射和字号估算 - # 如果使用 1920x1080,pixels_to_inches 会算出 20英寸 宽,导致内容溢出画布 - canvas_w_px = 1280 - canvas_h_px = 720 - - target_w_px = canvas_w_px - target_h_px = canvas_h_px - - slide_w_px = canvas_w_px - slide_h_px = canvas_h_px - - prs.slide_width = Inches(target_w_in) - prs.slide_height = Inches(target_h_in) - log.info( - f"[pdf2ppt_opt] PPT Size fixed to 16:9: {canvas_w_px}x{canvas_h_px} px (virtual), " - f"{target_w_in}x{target_h_in} in; is_figure_mode={is_figure_mode}" - ) - - # FIGURE 模式下,后面 _map_bbox_px 会把原图 contain 到该画布; - # 非 FIGURE 模式,我们会拉伸映射 (Stretch) 以匹配背景的填充方式。 - - base_dir = Path(_ensure_result_path(state)) - - # --- Pre-Phase: AI Background Generation --- - 
use_ai_bg = bool(getattr(state, "use_ai_edit", False)) - page_bg_map = {} # page_idx -> bg_path - ai_coroutines = [] - - if use_ai_bg: - log.info(f"[pdf2ppt_opt] AI Edit Enabled. Preparing background cleaning tasks...") - masks_dir = base_dir / "masks" - masks_dir.mkdir(parents=True, exist_ok=True) - bg_dir = base_dir / "clean_backgrounds" - bg_dir.mkdir(parents=True, exist_ok=True) - - # API Config - req_cfg = getattr(state, "request", None) or {} - if not isinstance(req_cfg, dict): - req_cfg = req_cfg.__dict__ if hasattr(req_cfg, "__dict__") else {} - api_key = req_cfg.get("api_key") or os.getenv("DF_API_KEY") - api_url = req_cfg.get("chat_api_url") or "https://api.apiyi.com" - model_name = req_cfg.get("gen_fig_model") or "gemini-3-pro-image-preview" - - if api_key: - for page_data in mineru_pages: - page_idx = page_data.get("page_idx", 0) - img_path = page_data.get("path") - if not img_path or not os.path.exists(img_path): - continue - - try: - clean_bg_path = bg_dir / f"clean_bg_{page_idx+1:03d}.png" - page_bg_map[page_idx] = str(clean_bg_path) - - # New Prompt: Single Image Edit - prompt = "Remove all text from the image, keeping only the background, figures, and icons. Do not change the layout or style. 
去除文字,只保留底色 图像 图标" - - async def _run_ai(ip=img_path, op=str(clean_bg_path)): - await _call_image_api_with_retry( - lambda: generate_or_edit_and_save_image_async( - prompt=prompt, - image_path=ip, - save_path=op, - api_url=api_url, - api_key=api_key, - model=model_name, - use_edit=True, - resolution="2K", - timeout=300, - ) - ) - - ai_coroutines.append(_run_ai()) - except Exception as e: - log.error(f"[pdf2ppt_opt] AI Edit setup failed page#{page_idx}: {e}") - - if ai_coroutines: - log.info(f"[pdf2ppt_opt] Waiting for {len(ai_coroutines)} AI tasks...") - await asyncio.gather(*ai_coroutines, return_exceptions=True) - else: - log.warning("[pdf2ppt_opt] use_ai_edit is True but no API Key found.") - - # Coordinate mapping: - # - non-FIGURE (PDF): Stretch mapping (匹配背景的拉伸填充) - # - FIGURE: scale original image into standard 16:9 canvas (contain) with centering - def _map_bbox_px(b: List[int], src_w_px: int, src_h_px: int) -> List[int]: - x1, y1, x2, y2 = b - - if src_w_px <= 0 or src_h_px <= 0: - return b - - if not is_figure_mode: - # PDF Mode: Stretch to fill - # 背景图是用 prs.slide_width/height 拉伸铺满的,所以前景也要同样拉伸 - sx = target_w_px / src_w_px - sy = target_h_px / src_h_px - return [ - int(round(x1 * sx)), - int(round(y1 * sy)), - int(round(x2 * sx)), - int(round(y2 * sy)), - ] - else: - # FIGURE Mode: Contain (Fit within canvas) - s = min(target_w_px / src_w_px, target_h_px / src_h_px) - dx = (target_w_px - src_w_px * s) / 2.0 - dy = (target_h_px - src_h_px * s) / 2.0 - - return [ - int(round(x1 * s + dx)), - int(round(y1 * s + dy)), - int(round(x2 * s + dx)), - int(round(y2 * s + dy)), - ] - - for page_data in mineru_pages: - page_idx = page_data.get("page_idx", 0) - mineru_blocks = page_data.get("blocks", []) - img_path = page_data.get("path") - - # Get corresponding SAM data - sam_data = sam_dict.get(page_idx, {}) - sam_items = sam_data.get("layout_items", []) - - slide = prs.slides.add_slide(prs.slide_layouts[6]) # Blank - - # Per-page source image size (used for 
mapping) - src_w_px, src_h_px = slide_w_px, slide_h_px - try: - with Image.open(img_path) as _img: - src_w_px, src_h_px = _img.size - except Exception: - pass - - # --- 1. Background --- - clean_bg = page_bg_map.get(page_idx) - if clean_bg and os.path.exists(clean_bg): - try: - # In FIGURE mode we also want the background to be scaled+centered into canvas, - # to keep alignment with mapped bboxes. - if is_figure_mode: - s_bg = min(target_w_px / src_w_px, target_h_px / src_h_px) - w_bg = int(round(src_w_px * s_bg)) - h_bg = int(round(src_h_px * s_bg)) - dx_bg = int(round((target_w_px - w_bg) / 2.0)) - dy_bg = int(round((target_h_px - h_bg) / 2.0)) - slide.shapes.add_picture( - clean_bg, - Inches(pixels_to_inches(dx_bg)), - Inches(pixels_to_inches(dy_bg)), - Inches(pixels_to_inches(w_bg)), - Inches(pixels_to_inches(h_bg)), - ) - else: - slide.shapes.add_picture(clean_bg, 0, 0, prs.slide_width, prs.slide_height) - except Exception as e: - log.error(f"[pdf2ppt_opt] Failed to add clean bg: {e}") - # Fallback to white - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - else: - # White background (Reconstruction Mode) - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - - # --- 2. 
Filter SAM Items (Icons/Logos) --- - # We filter out SAM items that overlap heavily with MinerU text/image blocks - final_sam_items = [] - - # Convert MinerU blocks to pixel bboxes for overlap check (mapped to canvas when FIGURE) - mineru_bboxes_px = [] - for blk in mineru_blocks: - bbox_n = blk.get("bbox") - if bbox_n: - x1 = int(bbox_n[0] * src_w_px) - y1 = int(bbox_n[1] * src_h_px) - x2 = int(bbox_n[2] * src_w_px) - y2 = int(bbox_n[3] * src_h_px) - mineru_bboxes_px.append(_map_bbox_px([x1, y1, x2, y2], src_w_px, src_h_px)) - - def _get_iou(boxA, boxB): - xA = max(boxA[0], boxB[0]) - yA = max(boxA[1], boxB[1]) - xB = min(boxA[2], boxB[2]) - yB = min(boxA[3], boxB[3]) - interArea = max(0, xB - xA) * max(0, yB - yA) - boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) - if boxAArea == 0: return 0 - return interArea / boxAArea - - for item in sam_items: - s_bbox_raw = item.get("bbox_px") - if not s_bbox_raw: continue - - # FIX: Map SAM bbox (source coords) to Canvas coords for correct IOU check - # MinerU bboxes are already mapped to canvas in mineru_bboxes_px - s_bbox_mapped = _map_bbox_px( - [s_bbox_raw[0], s_bbox_raw[1], s_bbox_raw[2], s_bbox_raw[3]], - src_w_px, - src_h_px - ) - - # Check overlap with MinerU blocks - is_overlap = False - for m_bbox in mineru_bboxes_px: - if _get_iou(s_bbox_mapped, m_bbox) > 0.5: - is_overlap = True - break - - if not is_overlap: - final_sam_items.append(item) - - # Render SAM items (mapped in FIGURE mode) - for item in final_sam_items: - fg_path = item.get("fg_png_path") - raw_path = item.get("png_path") - - # Log for diagnosis - log.info(f"[pdf2ppt_opt] SAM Item: raw_path={raw_path}, fg_path={fg_path}, exists_fg={os.path.exists(fg_path) if fg_path else False}") - - # Only render if fg_path exists (background removed successfully) - if fg_path and os.path.exists(fg_path): - ipath = fg_path - else: - log.warning(f"[pdf2ppt_opt] Skipping SAM item due to missing fg_png_path: {raw_path}") - continue - - bbox = item.get("bbox_px") 
- if not bbox: - continue - mb = _map_bbox_px([bbox[0], bbox[1], bbox[2], bbox[3]], src_w_px, src_h_px) - slide.shapes.add_picture( - ipath, - Inches(pixels_to_inches(mb[0])), - Inches(pixels_to_inches(mb[1])), - Inches(pixels_to_inches(mb[2] - mb[0])), - Inches(pixels_to_inches(mb[3] - mb[1])), - ) - - # --- 3. Process MinerU Blocks (Text & Images) --- - for blk in mineru_blocks: - btype = (blk.get("type") or "").lower() - bbox_n = blk.get("bbox") - if not bbox_n: continue - - # Convert to pixels (source image space) then map to slide canvas if FIGURE - x1 = int(bbox_n[0] * src_w_px) - y1 = int(bbox_n[1] * src_h_px) - x2 = int(bbox_n[2] * src_w_px) - y2 = int(bbox_n[3] * src_h_px) - bbox_px = _map_bbox_px([x1, y1, x2, y2], src_w_px, src_h_px) - - if x2 <= x1 or y2 <= y1: continue - - # A. Images/Tables -> Render directly (using MinerU crops or fallback) - if btype in ['image', 'figure', 'table', 'formula']: - # Try to find existing image path - ipath = blk.get("img_path") - # If not in block, maybe MinerU saved sub-images. - # For now, simple fallback: crop from original - if not ipath or not os.path.exists(ipath): - # Use fallback crop - fallback_dir = base_dir / "mineru_fallback" - fallback_dir.mkdir(parents=True, exist_ok=True) - ipath = str(fallback_dir / f"crop_{page_idx}_{id(blk)}.png") - if not os.path.exists(ipath): - try: - with Image.open(img_path) as page_img: - crop = page_img.crop((x1, y1, x2, y2)) - crop.save(ipath) - except: - ipath = None - - if ipath and os.path.exists(ipath): - slide.shapes.add_picture( - ipath, - Inches(pixels_to_inches(bbox_px[0])), - Inches(pixels_to_inches(bbox_px[1])), - Inches(pixels_to_inches(bbox_px[2] - bbox_px[0])), - Inches(pixels_to_inches(bbox_px[3] - bbox_px[1])), - ) - - # B. 
Text/Title -> Use MinerU Text - elif btype in ['text', 'title', 'header', 'footer', 'reference', 'list']: - text_level = 1 if btype in ['title', 'header'] else 3 - if btype == 'text': - text_level = 3 - - final_text = blk.get("text") or blk.get("content") or "" - - if not final_text.strip(): - continue - - # canvas size in px for bbox normalization & fitter coordinate system - canvas_w_px_for_fit = target_w_px if is_figure_mode else slide_w_px - canvas_h_px_for_fit = target_h_px if is_figure_mode else slide_h_px - - # 1. Fix pathological thin bboxes -> invisible text - bbox_px = _normalize_text_bbox_px( - bbox_px, - canvas_h_px=int(canvas_h_px_for_fit), - text=final_text, - text_level=text_level, - ) - - # 2. Color Sampling - text_color = _get_dominant_color(img_path, bbox_px) - - # 3. Add Text Box - # Enable render-fit for ALL text to ensure it fits within bbox - enable_render_fit = True - _add_smart_textbox( - slide, - final_text, - bbox_px, - text_level, - text_color, - slide_w_px=canvas_w_px_for_fit, - slide_h_px=canvas_h_px_for_fit, - enable_render_fit=enable_render_fit, - fitter_dpi=96, - ) - - # Save PPT - ppt_path = base_dir / "pdf2ppt_optimized_output.pptx" - prs.save(str(ppt_path)) - state.ppt_path = str(ppt_path) - log.info(f"[pdf2ppt_opt] PPT generated at: {ppt_path}") - - return state - - nodes = { - "_start_": _init_result_path, - "pdf_to_images": pdf_to_images_node, - "parallel_processing": parallel_processing_node, - "slides_ppt_generation": slides_ppt_generation_node, - "_end_": lambda state: state, - } - - edges = [ - ("pdf_to_images", "parallel_processing"), - ("parallel_processing", "slides_ppt_generation"), - ("slides_ppt_generation", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "pdf_to_images") - return builder diff --git a/dataflow_agent/workflow/wf_pdf2ppt_parallel.py b/dataflow_agent/workflow/wf_pdf2ppt_parallel.py deleted file mode 100644 index 32b322b..0000000 --- 
a/dataflow_agent/workflow/wf_pdf2ppt_parallel.py +++ /dev/null @@ -1,1072 +0,0 @@ -""" -pdf2ppt_with_sam workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -基于 slides PDF: -1. 将 PDF 每页渲染为 PNG -2. 对每页图片用 PaddleOCR 做文字 OCR -3. 对每页图片用 MinerU 做版面分析(区分 Text vs Image/Table) -4. 对每页图片用 SAM 做图标 / 图块分割 -5. 智能合并: - - MinerU 划定 "图表区" (Image/Table) 和 "正文区"。 - - OCR 文本如果落在 "图表区" 则丢弃,防止图片上的文字重复生成。 - - SAM 图块如果落在 "图表区" 则丢弃(由 MinerU 负责);如果在 "正文区" 且包含文字则丢弃(防止把文字当图); - 剩下的 SAM 块被视为 "无字图标",进行抠图后保留。 - - MinerU 提取的图片直接复用其 sub_images 目录,不再手动裁剪。 - - 字体归一化:全局统计正文和标题字号,强制统一,保证整齐。 - - 使用 AI Inpainting 生成干净背景。 -""" - -from __future__ import annotations -import os -import asyncio -from pathlib import Path -from typing import List, Dict, Any, Optional -from collections import Counter - -import cv2 -import numpy as np -import fitz # PyMuPDF -import yaml -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root - -# Tools -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_bg_remove, free_bg_rm_model -from dataflow_agent.toolkits.multimodaltool.mineru_tool import recursive_mineru_layout -from dataflow_agent.toolkits.multimodaltool.req_img import gemini_multi_image_edit_async -from dataflow_agent.toolkits.multimodaltool import ppt_tool - -from pptx import Presentation -from pptx.util import Inches, Pt -from pptx.dml.color import RGBColor - -log = get_logger(__name__) - -# Load configuration from yaml -def load_server_config(): - root = get_project_root() - config_path = root / "conf" / "model_servers.yaml" - if not config_path.exists(): - log.warning(f"Config file not found 
at {config_path}, using defaults.") - return {} - try: - with open(config_path, "r") as f: - return yaml.safe_load(f) or {} - except Exception as e: - log.error(f"Failed to load config: {e}") - return {} - -SERVER_CONFIG = load_server_config() - -# Helper to construct URLs -def get_sam_urls(): - # Check env var first - if os.environ.get("SAM_SERVER_URLS"): - return os.environ.get("SAM_SERVER_URLS").split(",") - - # Try config - sam_cfg = SERVER_CONFIG.get("sam", {}) - instances = sam_cfg.get("instances", []) - if instances: - urls = [] - for inst in instances: - for port in inst.get("ports", []): - urls.append(f"http://127.0.0.1:{port}") - if urls: - return urls - - # Default - return ["http://localhost:8021", "http://localhost:8022","http://localhost:8023"] - -def get_ocr_urls(): - # Check env var first - if os.environ.get("OCR_SERVER_URLS"): - return os.environ.get("OCR_SERVER_URLS").split(",") - - # Try config - ocr_cfg = SERVER_CONFIG.get("ocr", {}) - if ocr_cfg: - host = ocr_cfg.get("host", "0.0.0.0") - if host == "0.0.0.0": host = "127.0.0.1" - port = ocr_cfg.get("port", 8003) - return [f"http://{host}:{port}"] - - # Default - return ["http://localhost:8003"] - -SAM_SERVER_URLS = get_sam_urls() -OCR_SERVER_URLS = get_ocr_urls() - - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 为本次 pdf2ppt_with_sam workflow 创建统一的输出目录: - - 如果 state.result_path 已存在,直接使用; - - 否则使用项目根目录下 outputs/pdf2ppt_with_sam/。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(__import__("time").time()) - base_dir = (root / "outputs" / "pdf2ppt_with_sam" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _run_sam_on_pages(image_paths: List[str], base_dir: str) -> List[Dict[str, Any]]: - """ - 对每一页图片运行 SAM,输出 layout_items。 - """ - results: List[Dict[str, Any]] = [] - sam_ckpt = f"{get_project_root()}/sam_b.pt" - - for 
page_idx, img_path in enumerate(image_paths): - img_path_obj = Path(img_path) - if not img_path_obj.exists(): - log.warning(f"[pdf2ppt_with_sam] image not found for SAM: {img_path}") - results.append({"page_idx": page_idx, "layout_items": []}) - continue - - out_dir = Path(base_dir) / "layout_items" / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - # 1. SAM 分割 (使用远程服务) - try: - layout_items = segment_layout_boxes_server( - image_path=str(img_path_obj), - output_dir=str(out_dir), - server_urls=SAM_SERVER_URLS, - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=15, - nms_by="mask", - ) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] Remote SAM failed: {e}. Fallback to local.") - # Fallback to local if server fails - layout_items = segment_layout_boxes( - image_path=str(img_path_obj), - output_dir=str(out_dir), - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=15, - nms_by="mask", - ) - - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] SAM found {len(layout_items)} items") - - # 2. 
映射 bbox 到像素坐标(基于整页尺寸) - try: - pil_img = Image.open(str(img_path_obj)) - w, h = pil_img.size - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] open image failed: {e}") - w, h = 1024, 768 - - for it in layout_items: - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * w)) - y1 = int(round(y1n * h)) - x2 = int(round(x2n * w)) - y2 = int(round(y2n * h)) - if x2 > x1 and y2 > y1: - it["bbox_px"] = [x1, y1, x2, y2] - - results.append({"page_idx": page_idx, "layout_items": layout_items}) - - # 显式释放 SAM 模型 - try: - free_sam_model(checkpoint=sam_ckpt) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] free_sam_model failed: {e}") - - return results - - -@register("pdf2ppt_parallel") -def create_pdf2ppt_with_sam_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf pdf2ppt_with_sam - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - # ============================== - # NODES - # ============================== - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - async def pdf_to_images_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 将 PDF 每一页渲染为 PNG。 - 如果输入是 FIGURE (图片模式),直接使用 input_content 作为 slide_images。 - """ - if state.request.input_type == "FIGURE": - img_path = state.request.input_content - log.info(f"[pdf2ppt_with_sam] FIGURE mode: using input image {img_path}") - - # 强制开启 AI 编辑,以便在转 PPT 过程中去除背景文字 - state.use_ai_edit = True - - if img_path and os.path.exists(img_path): - state.slide_images = [img_path] - else: - log.error(f"[pdf2ppt_with_sam] FIGURE mode: image not found {img_path}") - return state - - pdf_path = getattr(state, "pdf_file", None) - if not pdf_path: - log.error("[pdf2ppt_with_sam] state.pdf_file is empty") - return state - - base_dir = Path(_ensure_result_path(state)) - img_dir = base_dir / "slides_png" - 
image_paths = ppt_tool.pdf_to_images(pdf_path, str(img_dir)) - state.slide_images = image_paths - return state - - async def slides_ocr_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页图片用 PaddleOCR 做 OCR。 - 使用 asyncio.to_thread 包装同步调用,避免阻塞事件循环。 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for OCR") - return state - - def _sync_ocr_all_pages(): - """同步执行所有页面的 OCR""" - ocr_pages: List[Dict[str, Any]] = [] - for page_idx, img_path in enumerate(image_paths): - try: - # 优先使用远程 OCR 服务 - try: - result = ppt_tool.paddle_ocr_page_with_layout_server(img_path, server_urls=OCR_SERVER_URLS) - except Exception as e: - log.warning(f"[pdf2ppt_with_sam][OCR] remote failed: {e}. Fallback to local.") - result = ppt_tool.paddle_ocr_page_with_layout(img_path) - except Exception as e: - log.error(f"[pdf2ppt_with_sam][OCR] page#{page_idx+1} failed: {e}") - result = { - "image_size": None, - "lines": [], - "body_h_px": None, - "bg_color": None, - "path": img_path, - "page_idx": page_idx, - } - result["page_idx"] = page_idx - result["path"] = img_path - ocr_pages.append(result) - return ocr_pages - - # 在线程池中执行同步 OCR,不阻塞事件循环 - ocr_pages = await asyncio.to_thread(_sync_ocr_all_pages) - state.ocr_pages = ocr_pages - return state - - async def slides_mineru_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页 PNG 使用 MinerU 做版面识别: - - 输出每页的 mineru_items,包含 type / bbox(norm) / text 等 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for MinerU") - return state - - base_dir = Path(_ensure_result_path(state)) - mineru_dir = base_dir / "mineru_pages" - mineru_dir.mkdir(parents=True, exist_ok=True) - - # MinerU 端口,优先从 state.request.mineru_port 读取 - # MinerU LB Port 8010 - port = getattr(getattr(state, "request", None), "mineru_port", 8010) - # 复杂度深度可从 state 或常量 - max_depth = 
getattr(state, "mask_detail_level", 3) - - mineru_pages: List[Dict[str, Any]] = [] - - for page_idx, img_path in enumerate(image_paths): - try: - out_dir = mineru_dir / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - log.critical(f"【mineru node】: {out_dir}") - - mineru_items = await recursive_mineru_layout( - image_path=str(img_path), - port=port, - max_depth=3, - output_dir=str(out_dir), - ) - - # 记录 MinerU 输出目录,方便后续找 sub_images - # recursive_mineru_layout 会在 out_dir 下直接输出或创建子目录 - # 这里我们记录 out_dir,后续可以在里面找 sub_images - - mineru_pages.append({ - "page_idx": page_idx, - "blocks": mineru_items, - "path": img_path, - "mineru_output_dir": str(out_dir) - }) - log.info(f"[pdf2ppt_with_sam][MinerU] page#{page_idx+1} got {len(mineru_items)} blocks") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][MinerU] page#{page_idx+1} failed: {e}") - mineru_pages.append({ - "page_idx": page_idx, - "blocks": [], - "path": img_path, - }) - - state.mineru_pages = mineru_pages - - log.critical(f"[state.mineru_pages]: {state.mineru_pages}") - - return state - - async def slides_sam_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页图片运行 SAM 用于图标 / 图块分割。 - 使用 asyncio.to_thread 包装同步调用,避免阻塞事件循环。 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for SAM") - return state - - base_dir = _ensure_result_path(state) - - # 在线程池中执行同步 SAM,不阻塞事件循环 - sam_pages = await asyncio.to_thread(_run_sam_on_pages, image_paths, base_dir) - state.sam_pages = sam_pages - return state - - async def slides_layout_bg_remove_node(state: Paper2FigureState, sam_pages: List[Dict[str, Any]] = None) -> Paper2FigureState: - """ - 对每一页 SAM layout PNG 做背景抠图: - - 输入: state.sam_pages[*].layout_items[].png_path 或传入的 sam_pages - - 输出: 为每个 layout_item 写入 fg_png_path(抠完背景的 PNG) - 使用 asyncio.to_thread 包装同步调用,避免阻塞事件循环。 - """ - # 支持从参数传入 sam_pages(用于并行分支) - if sam_pages is None: - sam_pages = 
getattr(state, "sam_pages", []) or [] - - if not sam_pages: - log.error("[pdf2ppt_with_sam] no sam_pages for bg remove") - return state - - base_dir = Path(_ensure_result_path(state)) - icons_dir = base_dir / "sam_icons" - icons_dir.mkdir(parents=True, exist_ok=True) - - model_path = getattr(getattr(state, "request", None), "bg_rm_model", None) - - def _sync_bg_remove(): - """同步执行所有背景移除""" - processed = 0 - for p in sam_pages: - page_idx = p.get("page_idx", 0) - for it in p.get("layout_items", []): - png_path = it.get("png_path") - if not png_path or not os.path.exists(png_path): - continue - - # 背景抠图 - 添加页码前缀避免文件名冲突 - try: - # 从原始路径提取文件名 - original_stem = Path(png_path).stem - # 创建带页码的输出文件名 - output_filename = f"page_{page_idx+1:03d}_{original_stem}_bg_removed.png" - output_path = icons_dir / output_filename - - req = { - "image_path": png_path, - "output_dir": str(icons_dir), - } - if model_path: - req["model_path"] = model_path - - fg_path = local_tool_for_bg_remove(req) - - # 重命名文件以包含页码 - if fg_path and os.path.exists(fg_path): - # 将生成的文件重命名为带页码的文件名 - fg_path_obj = Path(fg_path) - if fg_path_obj.name != output_filename: - new_fg_path = fg_path_obj.parent / output_filename - fg_path_obj.rename(new_fg_path) - fg_path = str(new_fg_path) - - it["fg_png_path"] = fg_path - else: - it["fg_png_path"] = png_path - - processed += 1 - except Exception as e: - log.error(f"[pdf2ppt_with_sam][bg_rm] failed for {png_path}: {e}") - it["fg_png_path"] = png_path - - # 抠图完成后可尝试释放模型(忽略失败) - try: - if model_path: - free_bg_rm_model(model_path=model_path) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] free_bg_rm_model failed: {e}") - - return processed - - # 在线程池中执行同步背景移除,不阻塞事件循环 - processed = await asyncio.to_thread(_sync_bg_remove) - - log.info(f"[pdf2ppt_with_sam] bg remove processed: {processed} items") - - # 将处理后的 sam_pages 写回 state - state.sam_pages = sam_pages - return state - - # ============================================================== - # 并行处理节点:同时执行 
OCR、MinerU、SAM+背景移除 - # ============================================================== - async def parallel_processing_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 并行执行三个分支: - 1. slides_ocr_node -> ocr_pages - 2. slides_mineru_node -> mineru_pages - 3. slides_sam_node + slides_layout_bg_remove_node -> sam_pages - - 三个分支共享 state.slide_images 作为输入,各自写入不同的输出字段。 - """ - import copy - import time - - log.info("[parallel_processing] 开始并行处理 OCR / MinerU / SAM+BgRemove ...") - start_time = time.time() - - # 定义三个分支任务 - async def ocr_branch(): - """OCR 分支""" - log.info("[parallel_processing][OCR] 分支启动") - branch_state = copy.copy(state) # 浅拷贝,共享 slide_images - result = await slides_ocr_node(branch_state) - log.info(f"[parallel_processing][OCR] 分支完成,提取了 {len(getattr(result, 'ocr_pages', []))} 页") - return ("ocr", result) - - async def mineru_branch(): - """MinerU 分支""" - log.info("[parallel_processing][MinerU] 分支启动") - branch_state = copy.copy(state) - result = await slides_mineru_node(branch_state) - log.info(f"[parallel_processing][MinerU] 分支完成,提取了 {len(getattr(result, 'mineru_pages', []))} 页") - return ("mineru", result) - - async def sam_branch(): - """SAM + 背景移除 分支(串行执行)""" - log.info("[parallel_processing][SAM] 分支启动") - branch_state = copy.copy(state) - - # 先执行 SAM - branch_state = await slides_sam_node(branch_state) - sam_pages = getattr(branch_state, "sam_pages", []) - log.info(f"[parallel_processing][SAM] SAM 完成,提取了 {len(sam_pages)} 页") - - # 再执行背景移除 - branch_state = await slides_layout_bg_remove_node(branch_state, sam_pages=sam_pages) - log.info("[parallel_processing][SAM] 背景移除完成") - - return ("sam", branch_state) - - # 并行执行三个分支 - results = await asyncio.gather( - ocr_branch(), - mineru_branch(), - sam_branch(), - return_exceptions=True - ) - - # 合并结果到 state - for r in results: - if isinstance(r, Exception): - log.error(f"[parallel_processing] 分支执行失败: {r}") - import traceback - traceback.print_exc() - continue - - branch_name, branch_state = r - - if 
branch_name == "ocr": - ocr_pages = getattr(branch_state, "ocr_pages", None) - if ocr_pages: - state.ocr_pages = ocr_pages - log.info(f"[parallel_processing] 合并 OCR 结果: {len(ocr_pages)} 页") - - elif branch_name == "mineru": - mineru_pages = getattr(branch_state, "mineru_pages", None) - if mineru_pages: - state.mineru_pages = mineru_pages - log.info(f"[parallel_processing] 合并 MinerU 结果: {len(mineru_pages)} 页") - - elif branch_name == "sam": - sam_pages = getattr(branch_state, "sam_pages", None) - if sam_pages: - state.sam_pages = sam_pages - log.info(f"[parallel_processing] 合并 SAM 结果: {len(sam_pages)} 页") - - elapsed = time.time() - start_time - log.info(f"[parallel_processing] 并行处理完成,耗时 {elapsed:.2f}s") - - return state - - async def slides_ppt_generation_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 结合 MinerU + OCR + SAM 结果生成可编辑 PPT: - - 改进点: - 1. MinerU 图片渲染修复:优先复用 MinerU 输出目录下的 sub_images,无法匹配时再手动裁剪。 - 2. 字体归一化: - - 统计全页正文(Body)文本的平均字号,取众数作为标准正文字号。 - - 标题(Title)字号设为标准正文的 1.5 倍(或取 Title 众数)。 - - 强制所有 Body 文本使用 Standard Body Font,所有 Title 文本使用 Standard Title Font。 - 3. 背景生成开关: - - 使用 state.use_ai_edit 控制是否调用 AI 生成纯净背景; - - 关闭时直接使用纯白背景。 - 4. 
并行 API 调用: - - 将 Inpainting API 调用改为并行执行,加快多页处理速度。 - """ - - ocr_pages: List[Dict[str, Any]] = getattr(state, "ocr_pages", []) or [] - sam_pages: List[Dict[str, Any]] = getattr(state, "sam_pages", []) or [] - mineru_pages: List[Dict[str, Any]] = getattr(state, "mineru_pages", []) or [] - - if not ocr_pages: - log.error("[pdf2ppt_with_sam] no ocr_pages, abort PPT generation") - return state - - # 建立索引 - sam_dict = {p.get("page_idx", 0): p.get("layout_items", []) for p in sam_pages} - - # mineru_dict 存放 {"blocks": [], "mineru_output_dir": ...} - # 修复:为了防止 page_idx 类型不一致 (int vs str),构建更鲁棒的索引 - mineru_dict = {} - for p in mineru_pages: - pid = p.get("page_idx", 0) - mineru_dict[pid] = p # 原始类型 - mineru_dict[str(pid)] = p # 字符串类型兼容 - - # 以 PPT 工具里的默认比例创建 Presentation - prs = Presentation() - prs.slide_width = Inches(ppt_tool.SLIDE_W_IN) - prs.slide_height = Inches(ppt_tool.SLIDE_H_IN) - slide_w_emu = prs.slide_width - slide_h_emu = prs.slide_height - - # 初始化 base_dir,确保后续逻辑都能访问 - base_dir = Path(_ensure_result_path(state)) - - # ========================================================== - # 辅助函数:API 重试逻辑 - # ========================================================== - async def _call_image_api_with_retry(coro_factory, retries: int = 3, delay: float = 1.0) -> bool: - """ - 对图像生成/编辑进行最多 retries 次重试。 - """ - last_err: Optional[Exception] = None - for attempt in range(1, retries + 1): - try: - await coro_factory() - return True - except Exception as e: - last_err = e - log.error(f"[pdf2ppt_with_sam] image api failed attempt {attempt}/{retries}: {e}") - if attempt < retries: - try: - await asyncio.sleep(delay) - except Exception: - pass - log.error(f"[pdf2ppt_with_sam] image api failed after {retries} attempts: {last_err}") - return False - - # ========================================================== - # 辅助函数:字体和几何计算 - # ========================================================== - def _bbox_area(bbox): - return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1]) - - def 
_get_intersection_area(bbox1, bbox2): - x1 = max(bbox1[0], bbox2[0]) - y1 = max(bbox1[1], bbox2[1]) - x2 = min(bbox1[2], bbox2[2]) - y2 = min(bbox1[3], bbox2[3]) - return max(0, x2 - x1) * max(0, y2 - y1) - - def _is_inside(inner, outer, threshold=0.9): - inter = _get_intersection_area(inner, outer) - inner_a = _bbox_area(inner) - if inner_a <= 0: return False - return (inter / inner_a) >= threshold - - def _is_overlap(bbox1, bbox2, threshold=0.1): - inter = _get_intersection_area(bbox1, bbox2) - min_area = min(_bbox_area(bbox1), _bbox_area(bbox2)) - if min_area <= 0: return False - return (inter / min_area) >= threshold - - # ========================================================== - # Phase 1: 准备渲染数据 & 创建 AI 任务 - # ========================================================== - - pages_render_data: List[Dict[str, Any]] = [] - ai_coroutines = [] # List of awaitables - - # 循环处理每一页的布局分析 - for pinfo in ocr_pages: - page_idx = pinfo.get("page_idx", 0) - - # 兼容性查找 - mineru_page_data = mineru_dict.get(page_idx) - if not mineru_page_data: - mineru_page_data = mineru_dict.get(str(page_idx), {}) - if mineru_page_data: - log.warning(f"[pdf2ppt_with_sam] page_idx mismatch fixed by str conversion: {page_idx}") - - img_path = pinfo.get("path") - lines = pinfo.get("lines", []) # List of (bbox, text, conf) - - if not img_path or not os.path.exists(img_path): - log.warning(f"[pdf2ppt_with_sam] missing img for page#{page_idx+1}: {img_path}") - continue - - # 读取原始图像信息 - try: - pil_img = Image.open(img_path) - w0, h0 = pil_img.size - except Exception as e: - log.error(f"Failed to open image {img_path}: {e}") - continue - - # ----------------------------------------------------------- - # Step 1: 分析 MinerU 结果,划定 "Image Zone" 并找回 sub_images - # ----------------------------------------------------------- - mineru_blocks = mineru_page_data.get("blocks", []) - mineru_out_dir = mineru_page_data.get("mineru_output_dir") - - image_zones = [] # List of {"bbox": [x1,y1,x2,y2], "type": str, 
"img_path": str} - - # 尝试定位 sub_images 目录 - sub_images_dir = None - sub_images_dirs: List[Path] = [] - if mineru_out_dir: - try: - page_root = Path(mineru_out_dir) - direct = page_root / "sub_images" - if direct.exists() and direct.is_dir(): - sub_images_dirs.append(direct) - for d in page_root.rglob("sub_images"): - if d.is_dir(): - sub_images_dirs.append(d) - seen = set() - unique_dirs: List[Path] = [] - for d in sub_images_dirs: - rp = str(d.resolve()) - if rp not in seen: - seen.add(rp) - unique_dirs.append(d) - for d in unique_dirs: - pngs = list(d.glob("*.png")) - if pngs: - sub_images_dir = d - break - if sub_images_dir: - sub_files = sorted([p.name for p in sub_images_dir.glob("*.png")]) - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] MinerU sub_images dir: {sub_images_dir}, found {len(sub_files)} pngs") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] search sub_images failed: {e}") - - for idx, blk in enumerate(mineru_blocks): - btype = (blk.get("type") or "").lower() - bbox = blk.get("bbox") # norm - if not bbox or len(bbox) != 4: - continue - - x1 = int(round(bbox[0] * w0)) - y1 = int(round(bbox[1] * h0)) - x2 = int(round(bbox[2] * w0)) - y2 = int(round(bbox[3] * h0)) - - if x2 <= x1 or y2 <= y1: continue - px_bbox = [x1, y1, x2, y2] - - is_image_zone = btype in ['image', 'figure', 'table', 'formula'] - img_path_found = None - - if is_image_zone: - if blk.get("img_path") and os.path.exists(blk["img_path"]): - img_path_found = blk["img_path"] - - if not img_path_found and sub_images_dir: - try: - depth = blk.get("depth", 0) - try: - depth = int(depth) - except Exception: - depth = 0 - prefix = f"depth{depth}_blk{idx}_" - for f in sorted(sub_images_dir.glob("*.png")): - if f.name.startswith(prefix): - img_path_found = str(f.resolve()) - break - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] match sub_images failed: {e}") - - if not img_path_found: - fallback_dir = base_dir / 
"mineru_fallback_crops" / f"page_{page_idx+1:03d}" - fallback_dir.mkdir(parents=True, exist_ok=True) - save_path = fallback_dir / f"mineru_{idx}_{btype}.png" - try: - if not save_path.exists(): - crop = pil_img.crop((x1, y1, x2, y2)) - crop.save(save_path) - img_path_found = str(save_path) - except Exception as e: - log.error(f"Failed to crop mineru block {idx}: {e}") - - if img_path_found: - image_zones.append({ - "bbox": px_bbox, - "type": btype, - "img_path": img_path_found - }) - - # ----------------------------------------------------------- - # Step 2: 过滤 OCR 文字 - # ----------------------------------------------------------- - final_ocr_lines = [] # (bbox, text, conf, type, raw_pt) - - for line in lines: - l_bbox, l_text, l_conf = line - is_in_image = False - for zone in image_zones: - if _is_inside(l_bbox, zone["bbox"]): - is_in_image = True - break - - if not is_in_image: - l_type = "body" - for blk in mineru_blocks: - btype = (blk.get("type") or "").lower() - b_bbox = blk.get("bbox") - if not b_bbox: continue - bx1 = int(round(b_bbox[0] * w0)) - by1 = int(round(b_bbox[1] * h0)) - bx2 = int(round(b_bbox[2] * w0)) - by2 = int(round(b_bbox[3] * h0)) - - if btype in ['title', 'header'] and _is_inside(l_bbox, [bx1, by1, bx2, by2]): - l_type = "title" - break - - # 预先计算原始字号,方便后续聚类 - raw_pt_obj = ppt_tool.estimate_font_pt(l_bbox, img_h_px=h0, body_h_px=None) - raw_pt = raw_pt_obj.pt if hasattr(raw_pt_obj, "pt") else raw_pt_obj - - final_ocr_lines.append((l_bbox, l_text, l_conf, l_type, raw_pt)) - - # ----------------------------------------------------------- - # Step 3: 过滤 SAM 图块 - # ----------------------------------------------------------- - raw_sam_items = sam_dict.get(page_idx, []) - final_sam_items = [] - - for item in raw_sam_items: - s_bbox = item.get("bbox_px") - if not s_bbox: continue - is_in_image = False - for zone in image_zones: - if _is_inside(s_bbox, zone["bbox"], threshold=0.6): - is_in_image = True - break - if is_in_image: continue - - 
is_text_block = False - for line in final_ocr_lines: - l_bbox = line[0] - if _is_overlap(s_bbox, l_bbox, threshold=0.3) or _is_inside(l_bbox, s_bbox): - is_text_block = True - break - if is_text_block: continue - - w = s_bbox[2] - s_bbox[0] - h = s_bbox[3] - s_bbox[1] - if w < 5 or h < 5: continue - if w*h < 400: continue - - final_sam_items.append(item) - - # ----------------------------------------------------------- - # Step 4: 准备 AI 背景生成任务 - # ----------------------------------------------------------- - clean_bg_path = base_dir / "clean_backgrounds" / f"clean_bg_{page_idx+1:03d}.png" - clean_bg_path.parent.mkdir(parents=True, exist_ok=True) - - use_ai_bg = bool(getattr(state, "use_ai_edit", False)) - log.critical(f"[pdf2ppt 是否使用AI: ][page#{page_idx+1}] use_ai_bg={use_ai_bg}") - - ai_task = None - if use_ai_bg and os.path.exists(img_path): - try: - # A. 生成 Mask (黑底白框) - ori_cv = cv2.imread(img_path) - if ori_cv is not None: - h_cv, w_cv = ori_cv.shape[:2] - mask_cv = np.zeros((h_cv, w_cv), dtype=np.uint8) # 黑底 - - # 绘制 OCR 区域 (白框) - for line in final_ocr_lines: - bbox = line[0] - pad = 5 - mx1 = int(max(0, bbox[0] - pad)) - my1 = int(max(0, bbox[1] - pad)) - mx2 = int(min(w_cv, bbox[2] + pad)) - my2 = int(min(h_cv, bbox[3] + pad)) - cv2.rectangle(mask_cv, (mx1, my1), (mx2, my2), (255), -1) - - mask_path = base_dir / "masks" / f"mask_{page_idx+1:03d}.png" - mask_path.parent.mkdir(parents=True, exist_ok=True) - cv2.imwrite(str(mask_path), mask_cv) - - # B. 
准备 AI 调用闭包 - req_cfg = getattr(state, "request", None) or {} - if not isinstance(req_cfg, dict): - req_cfg = req_cfg.__dict__ if hasattr(req_cfg, "__dict__") else {} - - api_key = req_cfg.get("api_key") or os.getenv("DF_API_KEY") - api_url = req_cfg.get("chat_api_url") or "https://api.apiyi.com" - model_name = req_cfg.get("gen_fig_model") or "gemini-3-pro-image-preview" - - if api_key: - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] Scheduling Gemini Inpainting...") - prompt = ( - "Use the second image as a mask to remove text from the first image. " - "Fill the removed text areas with background texture to make it clean. " - "Keep non-text areas (figures, tables) unchanged." - ) - - async def _run_ai_job(_p_idx=page_idx, _img_p=img_path, _mask_p=str(mask_path), _out_p=str(clean_bg_path)): - await _call_image_api_with_retry( - lambda: gemini_multi_image_edit_async( - prompt=prompt, - image_paths=[_img_p, _mask_p], - save_path=_out_p, - api_url=api_url, - api_key=api_key, - model=model_name, - resolution="1K", - timeout=300 - ) - ) - - ai_task = _run_ai_job() - ai_coroutines.append(ai_task) - else: - log.warning("Skipping AI edit: No API Key provided") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] Prepare AI task failed: {e}") - - # 保存所有需要在渲染阶段使用的数据 - pages_render_data.append({ - "page_idx": page_idx, - "scale_x": slide_w_emu / w0, - "scale_y": slide_h_emu / h0, - "clean_bg_path": str(clean_bg_path), - "image_zones": image_zones, - "final_sam_items": final_sam_items, - "final_ocr_lines": final_ocr_lines, - "ai_task": ai_task # 用于追踪哪个页面发起了 AI 请求 - }) - - # ========================================================== - # Phase 2: 并发执行 AI 任务 & 字号聚类 - # ========================================================== - - # 2.1 字号聚类逻辑 - use_global_clustering = getattr(state, "use_global_font_clustering", False) - global_clusterer = None - - if use_global_clustering: - log.info("[pdf2ppt_with_sam] Performing GLOBAL font size clustering...") - 
all_sizes = [] - for p_data in pages_render_data: - for line in p_data["final_ocr_lines"]: - # line: (bbox, text, conf, type, raw_pt) - raw_pt = line[4] - if raw_pt and raw_pt > 0: - all_sizes.append(raw_pt) - - global_clusterer = ppt_tool.FontSizeClustering(n_clusters=3) - global_clusterer.fit(all_sizes) - - # 2.2 执行 AI 任务 - if ai_coroutines: - log.info(f"[pdf2ppt_with_sam] Executing {len(ai_coroutines)} AI background tasks in parallel...") - start_t = __import__("time").time() - # 忽略异常,确保后续 PPT 渲染能继续(失败的会降级为白底) - await asyncio.gather(*ai_coroutines, return_exceptions=True) - cost = __import__("time").time() - start_t - log.info(f"[pdf2ppt_with_sam] AI tasks finished. cost={cost:.2f}s") - - # ========================================================== - # Phase 3: 生成 PPT 页面 (组装) - # ========================================================== - for p_data in pages_render_data: - # 取出数据 - scale_x = p_data["scale_x"] - scale_y = p_data["scale_y"] - clean_bg_path = p_data["clean_bg_path"] - image_zones = p_data["image_zones"] - final_sam_items = p_data["final_sam_items"] - final_ocr_lines = p_data["final_ocr_lines"] - - # 准备当页的字号聚类器 - if use_global_clustering: - clusterer = global_clusterer - else: - # 单页聚类模式 - page_sizes = [l[4] for l in final_ocr_lines if l[4] > 0] - clusterer = ppt_tool.FontSizeClustering(n_clusters=3) - clusterer.fit(page_sizes) - - slide = prs.slides.add_slide(prs.slide_layouts[6]) - - # 3.1 设置背景 - bg_image_path_for_ppt = None - if os.path.exists(clean_bg_path): - bg_image_path_for_ppt = clean_bg_path - - if bg_image_path_for_ppt: - try: - slide.shapes.add_picture(bg_image_path_for_ppt, 0, 0, prs.slide_width, prs.slide_height) - except Exception as e: - log.error(f"Failed to set slide background image: {e}") - # 降级 - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - else: - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - - # 3.2 渲染 MinerU 
Image Zones - for zone in image_zones: - ipath = zone["img_path"] - if not os.path.exists(ipath): - log.warning(f"MinerU image path not found: {ipath}") - continue - - bbox = zone["bbox"] - left = ppt_tool.px_to_emu(bbox[0], scale_x) - top = ppt_tool.px_to_emu(bbox[1], scale_y) - width = ppt_tool.px_to_emu(bbox[2] - bbox[0], scale_x) - height = ppt_tool.px_to_emu(bbox[3] - bbox[1], scale_y) - - try: - slide.shapes.add_picture(ipath, left, top, width, height) - except Exception as e: - log.error(f"Failed to add mineru image: {e}") - - # 3.3 渲染 SAM Icons - for item in final_sam_items: - ipath = item.get("fg_png_path") or item.get("png_path") - if not ipath or not os.path.exists(ipath): continue - - bbox = item.get("bbox_px") - left = ppt_tool.px_to_emu(bbox[0], scale_x) - top = ppt_tool.px_to_emu(bbox[1], scale_y) - width = ppt_tool.px_to_emu(bbox[2] - bbox[0], scale_x) - height = ppt_tool.px_to_emu(bbox[3] - bbox[1], scale_y) - - try: - slide.shapes.add_picture(ipath, left, top, width, height) - except Exception as e: - log.error(f"Failed to add SAM icon: {e}") - - # 3.4 渲染 OCR Text - for line in final_ocr_lines: - bbox, text, conf, l_type, raw_pt = line - x1, y1, x2, y2 = bbox - if (x2 - x1) < 5 or (y2 - y1) < 5: continue - - left = ppt_tool.px_to_emu(x1, scale_x) - top = ppt_tool.px_to_emu(y1, scale_y) - width = max(1, ppt_tool.px_to_emu(x2 - x1, scale_x)) - height = max(1, ppt_tool.px_to_emu(y2 - y1, scale_y)) - - tb = slide.shapes.add_textbox(left, top, width, height) - tf = tb.text_frame - tf.clear() - tf.word_wrap = True - tb.fill.background() - tb.line.fill.background() - - p = tf.paragraphs[0] - p.text = text - - # 应用字号映射 - final_pt = clusterer.map(raw_pt) - p.font.size = Pt(final_pt) - - # MinerU 的 Title 标签只用于加粗,不再强制改变字号 - if l_type == "title": - p.font.bold = True - - p.font.color.rgb = RGBColor(0, 0, 0) - - # Save - # base_dir 已在函数开头定义 - ppt_path = base_dir / "pdf2ppt_with_sam_output.pptx" - prs.save(str(ppt_path)) - state.ppt_path = str(ppt_path) - 
log.info(f"[pdf2ppt_with_sam] PPT generated: {ppt_path}") - - return state - - nodes = { - "_start_": _init_result_path, - "pdf_to_images": pdf_to_images_node, - "parallel_processing": parallel_processing_node, # 新增:并行处理节点 - "slides_ppt_generation": slides_ppt_generation_node, - "_end_": lambda state: state, - } - - edges = [ - ("pdf_to_images", "parallel_processing"), # pdf_to_images 后进入并行处理 - ("parallel_processing", "slides_ppt_generation"), # 并行完成后汇合到 PPT 生成 - ("slides_ppt_generation", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "pdf_to_images") - return builder diff --git a/dataflow_agent/workflow/wf_pdf2ppt_qwenvl.py b/dataflow_agent/workflow/wf_pdf2ppt_qwenvl.py deleted file mode 100644 index 8bfac2b..0000000 --- a/dataflow_agent/workflow/wf_pdf2ppt_qwenvl.py +++ /dev/null @@ -1,842 +0,0 @@ -""" -pdf2ppt_qwenvl workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -基于 slides PDF,融合 VLM (Qwen-VL-OCR) 替代传统 PaddleOCR: -1. 将 PDF 每页渲染为 PNG -2. 对每页图片用 VLM (ImageTextBBoxAgent) 做文字识别与定位 (替代 PaddleOCR) -3. 对每页图片用 MinerU 做版面分析(区分 Text vs Image/Table) -4. 对每页图片用 SAM 做图标 / 图块分割 -5. AI 背景编辑: - - 基于 VLM 提取的 bbox 生成 mask (或利用 VLM 调试阶段的 no_text 图) - - 调用 Inpainting API:填补 mask 之后的白色区域,结合背景颜色做 Inpainting -6. 
智能合并与 PPT 生成: - - 结合 MinerU (版面), SAM (图标), VLM (文字) 结果生成 PPT。 -""" - -from __future__ import annotations -import os -import asyncio -from pathlib import Path -from typing import List, Dict, Any, Optional -from collections import Counter -import copy - -import cv2 -import numpy as np -import fitz # PyMuPDF -import yaml -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root -from dataflow_agent.agentroles import create_vlm_agent - -# Tools -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_bg_remove, free_bg_rm_model -from dataflow_agent.toolkits.multimodaltool.mineru_tool import recursive_mineru_layout -from dataflow_agent.toolkits.multimodaltool.req_img import generate_or_edit_and_save_image_async -from dataflow_agent.toolkits.multimodaltool import ppt_tool - -from pptx import Presentation -from pptx.util import Inches, Pt -from pptx.dml.color import RGBColor - -log = get_logger(__name__) - -# Load configuration from yaml -def load_server_config(): - root = get_project_root() - config_path = root / "conf" / "model_servers.yaml" - if not config_path.exists(): - log.warning(f"Config file not found at {config_path}, using defaults.") - return {} - try: - with open(config_path, "r") as f: - return yaml.safe_load(f) or {} - except Exception as e: - log.error(f"Failed to load config: {e}") - return {} - -SERVER_CONFIG = load_server_config() - -def get_closest_aspect_ratio(w: int, h: int) -> str: - """ - 计算最接近的合法 Gemini 比例 - """ - valid_ratios = ['1:1', '2:3', '3:2', '3:4', '4:3', '4:5', '5:4', '9:16', '16:9', '21:9'] - target_ratio = w / h - - best_ratio = '16:9' # default - 
min_diff = float('inf') - - for r_str in valid_ratios: - rw, rh = map(int, r_str.split(':')) - curr_ratio = rw / rh - diff = abs(target_ratio - curr_ratio) - if diff < min_diff: - min_diff = diff - best_ratio = r_str - - return best_ratio - -# Helper to construct URLs -def get_sam_urls(): - if os.environ.get("SAM_SERVER_URLS"): - return os.environ.get("SAM_SERVER_URLS").split(",") - sam_cfg = SERVER_CONFIG.get("sam", {}) - instances = sam_cfg.get("instances", []) - if instances: - urls = [] - for inst in instances: - for port in inst.get("ports", []): - urls.append(f"http://127.0.0.1:{port}") - if urls: - return urls - return ["http://localhost:8021", "http://localhost:8022","http://localhost:8023"] - -SAM_SERVER_URLS = get_sam_urls() - -def _ensure_result_path(state: Paper2FigureState) -> str: - raw = getattr(state, "result_path", None) - if raw: - return raw - root = get_project_root() - ts = int(__import__("time").time()) - base_dir = (root / "outputs" / "pdf2ppt_qwenvl" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - -def _process_single_sam_page(page_idx: int, img_path: str, base_dir: str) -> Dict[str, Any]: - sam_ckpt = f"{get_project_root()}/sam_b.pt" - log.info(f"[pdf2ppt_qwenvl][SAM] processing page#{page_idx+1}: {img_path}") - img_path_obj = Path(img_path) - if not img_path_obj.exists(): - log.warning(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} image not found") - return {"page_idx": page_idx, "layout_items": []} - - out_dir = Path(base_dir) / "layout_items" / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - layout_items = [] - try: - # 尝试远程调用,增加显式日志 - log.info(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} calling segment_layout_boxes_server urls={SAM_SERVER_URLS}") - layout_items = segment_layout_boxes_server( - image_path=str(img_path_obj), - output_dir=str(out_dir), - server_urls=SAM_SERVER_URLS, - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, 
- iou_threshold=0.4, - top_k=15, - nms_by="mask", - ) - log.info(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} server returned {len(layout_items)} items") - except Exception as e: - log.error(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} Remote SAM failed: {e}. Fallback to local.") - try: - layout_items = segment_layout_boxes( - image_path=str(img_path_obj), - output_dir=str(out_dir), - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=15, - nms_by="mask", - ) - log.info(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} local SAM returned {len(layout_items)} items") - except Exception as e_local: - log.error(f"[pdf2ppt_qwenvl][SAM] page#{page_idx+1} Local SAM failed: {e_local}") - layout_items = [] - - try: - pil_img = Image.open(str(img_path_obj)) - w, h = pil_img.size - except Exception: - w, h = 1024, 768 - - for it in layout_items: - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * w)) - y1 = int(round(y1n * h)) - x2 = int(round(x2n * w)) - y2 = int(round(y2n * h)) - if x2 > x1 and y2 > y1: - it["bbox_px"] = [x1, y1, x2, y2] - - return {"page_idx": page_idx, "layout_items": layout_items} - -def _run_sam_on_pages(image_paths: List[str], base_dir: str) -> List[Dict[str, Any]]: - from concurrent.futures import ThreadPoolExecutor, as_completed - - results: List[Dict[str, Any]] = [] - # 限制并发数 - max_workers = min(len(image_paths), 6) - - log.info(f"[pdf2ppt_qwenvl][SAM] starting parallel processing with {max_workers} workers") - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_idx = { - executor.submit(_process_single_sam_page, idx, path, base_dir): idx - for idx, path in enumerate(image_paths) - } - - for future in as_completed(future_to_idx): - try: - res = future.result() - results.append(res) - except Exception as e: - idx = future_to_idx[future] - log.error(f"[pdf2ppt_qwenvl][SAM] page#{idx+1} task exception: {e}") - results.append({"page_idx": idx, "layout_items": 
[]}) - - # 清理本地可能加载的模型 - sam_ckpt = f"{get_project_root()}/sam_b.pt" - try: - free_sam_model(checkpoint=sam_ckpt) - except Exception: - pass - - return sorted(results, key=lambda x: x["page_idx"]) - -@register("pdf2ppt_qwenvl") -def create_pdf2ppt_qwenvl_graph() -> GenericGraphBuilder: - """ - Workflow factory: dfa run --wf pdf2ppt_qwenvl - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - async def pdf_to_images_node(state: Paper2FigureState) -> Paper2FigureState: - if state.request.input_type == "FIGURE": - img_path = state.request.input_content - log.info(f"[pdf2ppt_qwenvl] FIGURE mode: using input image {img_path}") - state.use_ai_edit = True - if img_path and os.path.exists(img_path): - state.slide_images = [img_path] - else: - log.error(f"[pdf2ppt_qwenvl] FIGURE mode: image not found {img_path}") - return state - - pdf_path = getattr(state, "pdf_file", None) - if not pdf_path: - log.error("[pdf2ppt_qwenvl] state.pdf_file is empty") - return state - - base_dir = Path(_ensure_result_path(state)) - img_dir = base_dir / "slides_png" - image_paths = ppt_tool.pdf_to_images(pdf_path, str(img_dir)) - state.slide_images = image_paths - return state - - # --- 新增:VLM 节点 (替代原 OCR) --- - async def vlm_recognition_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 使用 VLM (ImageTextBBoxAgent) 提取文本和 bbox。 - 结果写入 state.vlm_pages。 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.warning("[pdf2ppt_qwenvl][vlm] no slide_images") - state.vlm_pages = [] - return state - - async def _process_single_image(page_idx: int, img_path: str) -> Dict[str, Any]: - try: - # 显式传递 result_path,确保 agent 内部能访问 - temp_state = copy.copy(state) - temp_state.result_path = state.result_path - - # Retry loop for VLM execution - max_retries = 3 - bbox_res = [] - - for attempt in 
range(max_retries): - try: - agent = create_vlm_agent( - name="ImageTextBBoxAgent", - model_name=getattr(state.request, "vlm_model", "qwen-vl-ocr-2025-11-20"), - chat_api_url=getattr(state.request, "chat_api_url", None), - vlm_mode="ocr", - additional_params={ - "input_image": img_path - } - ) - - log.info(f"[pdf2ppt_qwenvl][VLM] page#{page_idx+1} attempt {attempt+1}/{max_retries}...") - new_state = await agent.execute(temp_state) - bbox_res = getattr(new_state, "bbox_result", []) - - # Basic validation: ensure we got a list and it's not empty (unless image is truly blank) - # Here we assume a successful parse returns a list. If it failed to parse, base_agent usually returns error dict or empty. - # We can check if new_state has error info or if bbox_res is valid. - if isinstance(bbox_res, list): - # Success - break - else: - log.warning(f"[pdf2ppt_qwenvl][VLM] page#{page_idx+1} attempt {attempt+1} got invalid result: {type(bbox_res)}") - - except Exception as e: - log.warning(f"[pdf2ppt_qwenvl][VLM] page#{page_idx+1} attempt {attempt+1} failed: {e}") - if attempt == max_retries - 1: - raise e - # Continue to retry - await asyncio.sleep(1) - - if not isinstance(bbox_res, list): - bbox_res = [] - - # 修正 bbox 归一化 (0-1000 -> 0-1) - # 并生成 "no_text" mask 图,供后续 Inpainting 使用 - processed_items = [] - - # 读取原图尺寸 - pil_img = Image.open(img_path) - w, h = pil_img.size - - # 准备生成 no_text 图 - mask_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) - - VLM_SCALE = 1000.0 - - for it in bbox_res: - # 处理 rotate_rect - if "rotate_rect" in it and "bbox" not in it: - try: - rr = it["rotate_rect"] - if isinstance(rr, list) and len(rr) == 5: - cx, cy, rw, rh, angle = rr - rect = ((float(cx), float(cy)), (float(rw), float(rh)), float(angle)) - box = cv2.boxPoints(rect) - x_min = np.min(box[:, 0]) - x_max = np.max(box[:, 0]) - y_min = np.min(box[:, 1]) - y_max = np.max(box[:, 1]) - - it["bbox"] = [ - max(0.0, min(1.0, y_min / VLM_SCALE)), - max(0.0, min(1.0, x_min / VLM_SCALE)), - 
max(0.0, min(1.0, y_max / VLM_SCALE)), - max(0.0, min(1.0, x_max / VLM_SCALE)) - ] - except Exception: - pass - - if "bbox" in it: - processed_items.append(it) - # 在 mask_img 上将文字区域涂白 - y1_n, x1_n, y2_n, x2_n = it["bbox"] - x1 = int(x1_n * w) - y1 = int(y1_n * h) - x2 = int(x2_n * w) - y2 = int(y2_n * h) - # 稍微扩大一点 mask 区域以覆盖完全 - pad = 2 - x1 = max(0, x1 - pad) - y1 = max(0, y1 - pad) - x2 = min(w, x2 + pad) - y2 = min(h, y2 + pad) - - cv2.rectangle(mask_img, (x1, y1), (x2, y2), (255, 255, 255), -1) - - # 保存 no_text 图片 - base_dir = Path(_ensure_result_path(state)) - debug_dir = base_dir / "vlm_debug" - debug_dir.mkdir(parents=True, exist_ok=True) - no_text_path = debug_dir / f"page_{page_idx+1:03d}_no_text.png" - cv2.imwrite(str(no_text_path), mask_img) - - log.info(f"[pdf2ppt_qwenvl][VLM] page#{page_idx+1} items={len(processed_items)}, saved mask to {no_text_path}") - - return { - "page_idx": page_idx, - "path": img_path, - "vlm_data": processed_items, - "no_text_path": str(no_text_path) - } - - except Exception as e: - log.error(f"[pdf2ppt_qwenvl][VLM] page#{page_idx+1} failed: {e}") - return { - "page_idx": page_idx, - "path": img_path, - "vlm_data": [], - "error": str(e) - } - - tasks = [_process_single_image(i, p) for i, p in enumerate(image_paths)] - results = await asyncio.gather(*tasks) - state.vlm_pages = results - return state - - async def slides_mineru_node(state: Paper2FigureState) -> Paper2FigureState: - """MinerU 版面分析 (并行优化)""" - image_paths: List[str] = getattr(state, "slide_images", []) or [] - log.info(f"[pdf2ppt_qwenvl][MinerU] start, images={len(image_paths)}") - if not image_paths: - return state - - base_dir = Path(_ensure_result_path(state)) - mineru_dir = base_dir / "mineru_pages" - mineru_dir.mkdir(parents=True, exist_ok=True) - port = getattr(getattr(state, "request", None), "mineru_port", 8010) - - async def _process_mineru_page(page_idx: int, img_path: str) -> Dict[str, Any]: - try: - out_dir = mineru_dir / f"page_{page_idx+1:03d}" - 
out_dir.mkdir(parents=True, exist_ok=True) - log.info(f"[pdf2ppt_qwenvl][MinerU] page#{page_idx+1} calling recursive_mineru_layout, port={port}") - - # 加超时保护,避免挂死 - try: - mineru_items = await asyncio.wait_for( - recursive_mineru_layout( - image_path=str(img_path), - port=port, - max_depth=3, - output_dir=str(out_dir), - ), - timeout=120.0, - ) - except asyncio.TimeoutError: - log.error(f"[pdf2ppt_qwenvl][MinerU] page#{page_idx+1} MinerU timeout (>120s)") - mineru_items = [] - except Exception as inner_e: - log.error(f"[pdf2ppt_qwenvl][MinerU] page#{page_idx+1} MinerU exception: {inner_e}") - mineru_items = [] - - log.info(f"[pdf2ppt_qwenvl][MinerU] page#{page_idx+1} got {len(mineru_items)} blocks") - return { - "page_idx": page_idx, - "blocks": mineru_items, - "path": img_path, - "mineru_output_dir": str(out_dir) - } - except Exception as e: - log.error(f"[pdf2ppt_qwenvl][MinerU] page#{page_idx+1} failed: {e}") - return { - "page_idx": page_idx, - "blocks": [], - "path": img_path - } - - tasks = [_process_mineru_page(i, p) for i, p in enumerate(image_paths)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 过滤异常并打印 - cleaned_results: List[Dict[str, Any]] = [] - for r in results: - if isinstance(r, Exception): - log.error(f"[pdf2ppt_qwenvl][MinerU] task exception: {r}") - continue - cleaned_results.append(r) - - # 按 page_idx 排序确保顺序 - state.mineru_pages = sorted(cleaned_results, key=lambda x: x["page_idx"]) - log.info(f"[pdf2ppt_qwenvl][MinerU] done, pages={len(state.mineru_pages)}") - return state - - async def slides_sam_node(state: Paper2FigureState) -> Paper2FigureState: - """SAM 图标分割""" - image_paths: List[str] = getattr(state, "slide_images", []) or [] - log.info(f"[pdf2ppt_qwenvl][SAM] start, images={len(image_paths)}") - if not image_paths: - return state - base_dir = _ensure_result_path(state) - try: - sam_pages = await asyncio.to_thread(_run_sam_on_pages, image_paths, base_dir) - log.info(f"[pdf2ppt_qwenvl][SAM] done, 
pages={len(sam_pages)}") - except Exception as e: - log.error(f"[pdf2ppt_qwenvl][SAM] SAM processing failed: {e}") - sam_pages = [] - state.sam_pages = sam_pages - return state - - async def slides_layout_bg_remove_node(state: Paper2FigureState, sam_pages: List[Dict[str, Any]] = None) -> Paper2FigureState: - """SAM 结果背景移除""" - if sam_pages is None: - sam_pages = getattr(state, "sam_pages", []) or [] - if not sam_pages: - return state - - base_dir = Path(_ensure_result_path(state)) - icons_dir = base_dir / "sam_icons" - icons_dir.mkdir(parents=True, exist_ok=True) - model_path = getattr(getattr(state, "request", None), "bg_rm_model", None) - - def _sync_bg_remove(): - processed = 0 - for p in sam_pages: - page_idx = p.get("page_idx", 0) - for it in p.get("layout_items", []): - png_path = it.get("png_path") - if not png_path or not os.path.exists(png_path): continue - - try: - original_stem = Path(png_path).stem - output_filename = f"page_{page_idx+1:03d}_{original_stem}_bg_removed.png" - output_path = icons_dir / output_filename - - req = {"image_path": png_path, "output_dir": str(icons_dir)} - if model_path: req["model_path"] = model_path - - fg_path = local_tool_for_bg_remove(req) - - if fg_path and os.path.exists(fg_path): - fg_path_obj = Path(fg_path) - if fg_path_obj.name != output_filename: - new_fg_path = fg_path_obj.parent / output_filename - fg_path_obj.rename(new_fg_path) - fg_path = str(new_fg_path) - it["fg_png_path"] = fg_path - else: - it["fg_png_path"] = png_path - processed += 1 - except Exception: - it["fg_png_path"] = png_path - - try: - if model_path: free_bg_rm_model(model_path=model_path) - except Exception: pass - return processed - - await asyncio.to_thread(_sync_bg_remove) - state.sam_pages = sam_pages - return state - - async def slides_inpainting_node(state: Paper2FigureState) -> Paper2FigureState: - """AI Inpainting: 填补文字 mask 区域""" - vlm_pages = getattr(state, "vlm_pages", []) or [] - if not vlm_pages: - return state - - base_dir = 
Path(_ensure_result_path(state)) - - # API 配置 - req_cfg = getattr(state, "request", None) or {} - if not isinstance(req_cfg, dict): req_cfg = req_cfg.__dict__ if hasattr(req_cfg, "__dict__") else {} - api_key = req_cfg.get("api_key") or os.getenv("DF_API_KEY") - api_url = req_cfg.get("chat_api_url") or "https://api.apiyi.com" - model_name = req_cfg.get("gen_fig_model") or "gemini-3-pro-image-preview" - - # 限制并发 - sem = asyncio.Semaphore(3) - - async def _call_image_api_with_retry(coro_factory, retries=3): - for i in range(retries): - try: - await coro_factory() - return True - except Exception as e: - if i == retries - 1: log.error(f"Image API failed: {e}") - await asyncio.sleep(1) - return False - - async def _process_inpainting(pinfo): - page_idx = pinfo.get("page_idx", 0) - no_text_path = pinfo.get("no_text_path") - - clean_bg_path = base_dir / "clean_bg" / f"bg_{page_idx+1:03d}.png" - clean_bg_path.parent.mkdir(parents=True, exist_ok=True) - - # 将结果路径回写到 pinfo,供后续步骤使用 - pinfo["clean_bg_path"] = str(clean_bg_path) - - if state.use_ai_edit and api_key and no_text_path and os.path.exists(no_text_path): - ratio_str = "16:9" - try: - with Image.open(no_text_path) as tmp_img: - ratio_str = get_closest_aspect_ratio(tmp_img.width, tmp_img.height) - except Exception: pass - - inpainting_prompt = "使用背景颜色,填充图里被mask的白色部分,去掉全部文字!" 
- - async with sem: - await _call_image_api_with_retry( - lambda: generate_or_edit_and_save_image_async( - prompt=inpainting_prompt, - save_path=str(clean_bg_path), - api_url=api_url, - api_key=api_key, - model=model_name, - use_edit=True, - image_path=no_text_path, - aspect_ratio=ratio_str, - resolution="2K" - ) - ) - else: - # 降级:复制 no_text 图 - if no_text_path and os.path.exists(no_text_path): - import shutil - try: shutil.copy(no_text_path, clean_bg_path) - except: pass - - tasks = [_process_inpainting(p) for p in vlm_pages] - if tasks: - log.info(f"[pdf2ppt_qwenvl][Inpainting] starting {len(tasks)} tasks in parallel with VLM flow") - await asyncio.gather(*tasks) - - return state - - # --- 并行处理节点 --- - async def parallel_processing_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 并行执行: - 1. VLM (OCR) -> Inpainting - 2. MinerU - 3. SAM + BgRemove - """ - import copy - import time - start_time = time.time() - - async def vlm_and_inpainting_branch(): - branch_state = copy.copy(state) - branch_state = await vlm_recognition_node(branch_state) - branch_state = await slides_inpainting_node(branch_state) - return ("vlm", branch_state) - - async def mineru_branch(): - branch_state = copy.copy(state) - result = await slides_mineru_node(branch_state) - return ("mineru", result) - - async def sam_branch(): - branch_state = copy.copy(state) - branch_state = await slides_sam_node(branch_state) - sam_pages = getattr(branch_state, "sam_pages", []) - branch_state = await slides_layout_bg_remove_node(branch_state, sam_pages=sam_pages) - return ("sam", branch_state) - - results = await asyncio.gather( - vlm_and_inpainting_branch(), - mineru_branch(), - sam_branch(), - return_exceptions=True - ) - - log.info(f"[pdf2ppt_qwenvl] parallel branches returned: {results}") - - for r in results: - if isinstance(r, Exception): - log.error(f"[pdf2ppt_qwenvl] Branch failed: {r}") - continue - branch_name, branch_state = r - if branch_name == "vlm": - state.vlm_pages = 
getattr(branch_state, "vlm_pages", []) - elif branch_name == "mineru": - state.mineru_pages = getattr(branch_state, "mineru_pages", []) - elif branch_name == "sam": - state.sam_pages = getattr(branch_state, "sam_pages", []) - - log.info(f"[pdf2ppt_qwenvl] Parallel processing finished in {time.time() - start_time:.2f}s") - return state - - async def slides_ppt_generation_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 生成 PPT: - 1. 整合 VLM, MinerU, SAM 结果 - 2. 渲染页面 (Inpainting 已经在上一步并行完成) - """ - vlm_pages = getattr(state, "vlm_pages", []) or [] - sam_pages = getattr(state, "sam_pages", []) or [] - mineru_pages = getattr(state, "mineru_pages", []) or [] - - if not vlm_pages: - log.error("[pdf2ppt_qwenvl] no vlm_pages, abort PPT generation") - return state - - # Indexing - sam_dict = {p.get("page_idx", 0): p.get("layout_items", []) for p in sam_pages} - mineru_dict = {} - for p in mineru_pages: - mineru_dict[p.get("page_idx", 0)] = p - - prs = Presentation() - prs.slide_width = Inches(ppt_tool.SLIDE_W_IN) - prs.slide_height = Inches(ppt_tool.SLIDE_H_IN) - slide_w_emu = prs.slide_width - slide_h_emu = prs.slide_height - - base_dir = Path(_ensure_result_path(state)) - - # 辅助几何函数 - def _bbox_area(bbox): return max(0, bbox[2]-bbox[0]) * max(0, bbox[3]-bbox[1]) - def _get_intersection_area(b1, b2): - x1,y1,x2,y2 = max(b1[0],b2[0]), max(b1[1],b2[1]), min(b1[2],b2[2]), min(b1[3],b2[3]) - return max(0, x2-x1) * max(0, y2-y1) - def _is_inside(inner, outer, th=0.9): - ia = _bbox_area(inner) - return (ia > 0) and ((_get_intersection_area(inner, outer) / ia) >= th) - - for pinfo in vlm_pages: - page_idx = pinfo.get("page_idx", 0) - img_path = pinfo.get("path") - vlm_data = pinfo.get("vlm_data", []) - # 从 vlm_pages 里获取 clean_bg_path,如果并行步骤成功,这里应该有值 - clean_bg_path = pinfo.get("clean_bg_path") - - if not img_path or not os.path.exists(img_path): continue - - try: - pil_img = Image.open(img_path) - w0, h0 = pil_img.size - except Exception: continue - - scale_x = 
slide_w_emu / w0 - scale_y = slide_h_emu / h0 - - # 1. MinerU Image Zones - mineru_data = mineru_dict.get(page_idx, {}) - mineru_blocks = mineru_data.get("blocks", []) - image_zones = [] - - sub_images_dir = None - if mineru_data.get("mineru_output_dir"): - try: - # 简单尝试找 sub_images - possibles = list(Path(mineru_data["mineru_output_dir"]).rglob("sub_images")) - if possibles: sub_images_dir = possibles[0] - except Exception: pass - - for idx, blk in enumerate(mineru_blocks): - btype = (blk.get("type") or "").lower() - bbox = blk.get("bbox") # norm - if not bbox or len(bbox)!=4: continue - if btype in ['image', 'figure', 'table', 'formula']: - x1, y1, x2, y2 = int(bbox[0]*w0), int(bbox[1]*h0), int(bbox[2]*w0), int(bbox[3]*h0) - px_bbox = [x1, y1, x2, y2] - - img_path_found = None - if blk.get("img_path") and os.path.exists(blk["img_path"]): - img_path_found = blk["img_path"] - - if not img_path_found: - fb_dir = base_dir / "mineru_fallback" / f"p{page_idx}" - fb_dir.mkdir(parents=True, exist_ok=True) - save_p = fb_dir / f"blk_{idx}.png" - if not save_p.exists(): - try: pil_img.crop((x1,y1,x2,y2)).save(save_p) - except: pass - img_path_found = str(save_p) - - if img_path_found: - image_zones.append({"bbox": px_bbox, "type": btype, "img_path": img_path_found}) - - # 2. 
VLM Text Filtering (Filter text inside images) - final_text_lines = [] - for it in vlm_data: - # it: {'bbox': [y1n, x1n, y2n, x2n], 'text': ...} (0-1 norm) - bbox_n = it.get("bbox") - if not bbox_n: continue - y1n, x1n, y2n, x2n = bbox_n - x1, y1, x2, y2 = int(x1n*w0), int(y1n*h0), int(x2n*w0), int(y2n*h0) - l_bbox = [x1, y1, x2, y2] - - is_in_image = False - # for z in image_zones: - # if _is_inside(l_bbox, z["bbox"]): - # is_in_image = True - # break - - if not is_in_image: - # 估算字号 - raw_pt_obj = ppt_tool.estimate_font_pt(l_bbox, img_h_px=h0, body_h_px=None) - raw_pt = raw_pt_obj.pt if hasattr(raw_pt_obj, "pt") else raw_pt_obj - - # 简单判断 title (基于字号或位置,这里简化) - l_type = "body" - if raw_pt > 18: l_type = "title" # 简单阈值,可改进 - - final_text_lines.append((l_bbox, it.get("text",""), 1.0, l_type, raw_pt)) - - # 3. SAM Icons Filtering - raw_sam = sam_dict.get(page_idx, []) - final_sam = [] - for item in raw_sam: - s_bbox = item.get("bbox_px") - if not s_bbox: continue - # Filter if inside Image Zone - if any(_is_inside(s_bbox, z["bbox"], 0.6) for z in image_zones): continue - # Filter if overlaps with Text - if any(_is_inside(line[0], s_bbox) for line in final_text_lines): continue - - final_sam.append(item) - - # 渲染 PPT 页面 - slide = prs.slides.add_slide(prs.slide_layouts[6]) - - # Background - if clean_bg_path and os.path.exists(clean_bg_path): - try: slide.shapes.add_picture(clean_bg_path, 0, 0, prs.slide_width, prs.slide_height) - except: pass - - # MinerU Images - for z in image_zones: - if os.path.exists(z["img_path"]): - bx = z["bbox"] - slide.shapes.add_picture(z["img_path"], - ppt_tool.px_to_emu(bx[0], scale_x), ppt_tool.px_to_emu(bx[1], scale_y), - ppt_tool.px_to_emu(bx[2]-bx[0], scale_x), ppt_tool.px_to_emu(bx[3]-bx[1], scale_y)) - - # SAM Icons - for s in final_sam: - path = s.get("fg_png_path") or s.get("png_path") - if path and os.path.exists(path): - bx = s["bbox_px"] - slide.shapes.add_picture(path, - ppt_tool.px_to_emu(bx[0], scale_x), 
ppt_tool.px_to_emu(bx[1], scale_y), - ppt_tool.px_to_emu(bx[2]-bx[0], scale_x), ppt_tool.px_to_emu(bx[3]-bx[1], scale_y)) - - # Text - for line in final_text_lines: - bbox, text, _, l_type, raw_pt = line - left = ppt_tool.px_to_emu(bbox[0], scale_x) - top = ppt_tool.px_to_emu(bbox[1], scale_y) - w = ppt_tool.px_to_emu(bbox[2]-bbox[0], scale_x) - h = ppt_tool.px_to_emu(bbox[3]-bbox[1], scale_y) - - tb = slide.shapes.add_textbox(left, top, w, h) - p = tb.text_frame.paragraphs[0] - p.text = text - p.font.size = Pt(raw_pt if raw_pt > 5 else 12) - p.font.bold = (l_type == "title") - p.font.color.rgb = RGBColor(0,0,0) - - out_path = base_dir / "pdf2ppt_qwenvl_output.pptx" - prs.save(str(out_path)) - state.ppt_path = str(out_path) - log.info(f"[pdf2ppt_qwenvl] PPT Generated: {out_path}") - return state - - nodes = { - "_start_": _init_result_path, - "pdf_to_images": pdf_to_images_node, - "parallel_processing": parallel_processing_node, - "slides_ppt_generation": slides_ppt_generation_node, - "_end_": lambda s: s, - } - - edges = [ - ("pdf_to_images", "parallel_processing"), - ("parallel_processing", "slides_ppt_generation"), - ("slides_ppt_generation", "_end_"), - ] - - builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "pdf_to_images") - return builder diff --git a/dataflow_agent/workflow/wf_pdf2ppt_with_sam_ocr_mineru.py b/dataflow_agent/workflow/wf_pdf2ppt_with_sam_ocr_mineru.py deleted file mode 100644 index f1cc2f8..0000000 --- a/dataflow_agent/workflow/wf_pdf2ppt_with_sam_ocr_mineru.py +++ /dev/null @@ -1,939 +0,0 @@ -""" -pdf2ppt_with_sam workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -基于 slides PDF: -1. 将 PDF 每页渲染为 PNG -2. 对每页图片用 PaddleOCR 做文字 OCR -3. 对每页图片用 MinerU 做版面分析(区分 Text vs Image/Table) -4. 对每页图片用 SAM 做图标 / 图块分割 -5. 
智能合并: - - MinerU 划定 "图表区" (Image/Table) 和 "正文区"。 - - OCR 文本如果落在 "图表区" 则丢弃,防止图片上的文字重复生成。 - - SAM 图块如果落在 "图表区" 则丢弃(由 MinerU 负责);如果在 "正文区" 且包含文字则丢弃(防止把文字当图); - 剩下的 SAM 块被视为 "无字图标",进行抠图后保留。 - - MinerU 提取的图片直接复用其 sub_images 目录,不再手动裁剪。 - - 字体归一化:全局统计正文和标题字号,强制统一,保证整齐。 - - 使用 AI Inpainting 生成干净背景。 -""" - -from __future__ import annotations -import os -import asyncio -from pathlib import Path -from typing import List, Dict, Any, Optional -from collections import Counter - -import cv2 -import numpy as np -import fitz # PyMuPDF -import yaml -from PIL import Image - -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.utils import get_project_root - -# Tools -from dataflow_agent.toolkits.multimodaltool.sam_tool import segment_layout_boxes, segment_layout_boxes_server, free_sam_model -from dataflow_agent.toolkits.multimodaltool.bg_tool import local_tool_for_bg_remove, free_bg_rm_model -from dataflow_agent.toolkits.multimodaltool.mineru_tool import recursive_mineru_layout -from dataflow_agent.toolkits.multimodaltool.req_img import gemini_multi_image_edit_async -from dataflow_agent.toolkits.multimodaltool import ppt_tool - -from pptx import Presentation -from pptx.util import Inches, Pt -from pptx.dml.color import RGBColor - -log = get_logger(__name__) - -# Load configuration from yaml -def load_server_config(): - root = get_project_root() - config_path = root / "conf" / "model_servers.yaml" - if not config_path.exists(): - log.warning(f"Config file not found at {config_path}, using defaults.") - return {} - try: - with open(config_path, "r") as f: - return yaml.safe_load(f) or {} - except Exception as e: - log.error(f"Failed to load config: {e}") - return {} - -SERVER_CONFIG = load_server_config() - -# Helper to construct URLs -def get_sam_urls(): - # Check env var first - if 
os.environ.get("SAM_SERVER_URLS"): - return os.environ.get("SAM_SERVER_URLS").split(",") - - # Try config - sam_cfg = SERVER_CONFIG.get("sam", {}) - instances = sam_cfg.get("instances", []) - if instances: - urls = [] - for inst in instances: - for port in inst.get("ports", []): - urls.append(f"http://127.0.0.1:{port}") - if urls: - return urls - - # Default - return ["http://localhost:8021", "http://localhost:8022"] - -def get_ocr_urls(): - # Check env var first - if os.environ.get("OCR_SERVER_URLS"): - return os.environ.get("OCR_SERVER_URLS").split(",") - - # Try config - ocr_cfg = SERVER_CONFIG.get("ocr", {}) - if ocr_cfg: - host = ocr_cfg.get("host", "0.0.0.0") - if host == "0.0.0.0": host = "127.0.0.1" - port = ocr_cfg.get("port", 8003) - return [f"http://{host}:{port}"] - - # Default - return ["http://localhost:8003"] - -SAM_SERVER_URLS = get_sam_urls() -OCR_SERVER_URLS = get_ocr_urls() - - -def _ensure_result_path(state: Paper2FigureState) -> str: - """ - 为本次 pdf2ppt_with_sam workflow 创建统一的输出目录: - - 如果 state.result_path 已存在,直接使用; - - 否则使用项目根目录下 outputs/pdf2ppt_with_sam/。 - """ - raw = getattr(state, "result_path", None) - if raw: - return raw - - root = get_project_root() - ts = int(__import__("time").time()) - base_dir = (root / "outputs" / "pdf2ppt_with_sam" / str(ts)).resolve() - base_dir.mkdir(parents=True, exist_ok=True) - state.result_path = str(base_dir) - return state.result_path - - -def _run_sam_on_pages(image_paths: List[str], base_dir: str) -> List[Dict[str, Any]]: - """ - 对每一页图片运行 SAM,输出 layout_items。 - """ - results: List[Dict[str, Any]] = [] - sam_ckpt = f"{get_project_root()}/sam_b.pt" - - for page_idx, img_path in enumerate(image_paths): - img_path_obj = Path(img_path) - if not img_path_obj.exists(): - log.warning(f"[pdf2ppt_with_sam] image not found for SAM: {img_path}") - results.append({"page_idx": page_idx, "layout_items": []}) - continue - - out_dir = Path(base_dir) / "layout_items" / f"page_{page_idx+1:03d}" - 
out_dir.mkdir(parents=True, exist_ok=True) - - # 1. SAM 分割 (使用远程服务) - try: - layout_items = segment_layout_boxes_server( - image_path=str(img_path_obj), - output_dir=str(out_dir), - server_urls=SAM_SERVER_URLS, - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=25, - nms_by="mask", - ) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] Remote SAM failed: {e}. Fallback to local.") - # Fallback to local if server fails - layout_items = segment_layout_boxes( - image_path=str(img_path_obj), - output_dir=str(out_dir), - checkpoint=sam_ckpt, - min_area=200, - min_score=0.0, - iou_threshold=0.4, - top_k=25, - nms_by="mask", - ) - - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] SAM found {len(layout_items)} items") - - # 2. 映射 bbox 到像素坐标(基于整页尺寸) - try: - pil_img = Image.open(str(img_path_obj)) - w, h = pil_img.size - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] open image failed: {e}") - w, h = 1024, 768 - - for it in layout_items: - bbox = it.get("bbox") - if bbox and len(bbox) == 4: - x1n, y1n, x2n, y2n = bbox - x1 = int(round(x1n * w)) - y1 = int(round(y1n * h)) - x2 = int(round(x2n * w)) - y2 = int(round(y2n * h)) - if x2 > x1 and y2 > y1: - it["bbox_px"] = [x1, y1, x2, y2] - - results.append({"page_idx": page_idx, "layout_items": layout_items}) - - # 显式释放 SAM 模型 - try: - free_sam_model(checkpoint=sam_ckpt) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] free_sam_model failed: {e}") - - return results - - -@register("pdf2ppt_with_sam_ocr_mineru") -def create_pdf2ppt_with_sam_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf pdf2ppt_with_sam - """ - builder = GenericGraphBuilder(state_model=Paper2FigureState, entry_point="_start_") - - # ============================== - # NODES - # ============================== - - def _init_result_path(state: Paper2FigureState) -> Paper2FigureState: - _ensure_result_path(state) - return state - - async def 
pdf_to_images_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 将 PDF 每一页渲染为 PNG。 - """ - pdf_path = getattr(state, "pdf_file", None) - if not pdf_path: - log.error("[pdf2ppt_with_sam] state.pdf_file is empty") - return state - - base_dir = Path(_ensure_result_path(state)) - img_dir = base_dir / "slides_png" - image_paths = ppt_tool.pdf_to_images(pdf_path, str(img_dir)) - state.slide_images = image_paths - return state - - async def slides_ocr_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页图片用 PaddleOCR 做 OCR。 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for OCR") - return state - - ocr_pages: List[Dict[str, Any]] = [] - for page_idx, img_path in enumerate(image_paths): - try: - # 优先使用远程 OCR 服务 - try: - result = ppt_tool.paddle_ocr_page_with_layout_server(img_path, server_urls=OCR_SERVER_URLS) - except Exception as e: - log.warning(f"[pdf2ppt_with_sam][OCR] remote failed: {e}. 
Fallback to local.") - result = ppt_tool.paddle_ocr_page_with_layout(img_path) - except Exception as e: - log.error(f"[pdf2ppt_with_sam][OCR] page#{page_idx+1} failed: {e}") - result = { - "image_size": None, - "lines": [], - "body_h_px": None, - "bg_color": None, - "path": img_path, - "page_idx": page_idx, - } - result["page_idx"] = page_idx - result["path"] = img_path - ocr_pages.append(result) - - state.ocr_pages = ocr_pages - return state - - async def slides_mineru_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页 PNG 使用 MinerU 做版面识别: - - 输出每页的 mineru_items,包含 type / bbox(norm) / text 等 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for MinerU") - return state - - base_dir = Path(_ensure_result_path(state)) - mineru_dir = base_dir / "mineru_pages" - mineru_dir.mkdir(parents=True, exist_ok=True) - - # MinerU 端口,优先从 state.request.mineru_port 读取 - # MinerU LB Port 8010 - port = getattr(getattr(state, "request", None), "mineru_port", 8010) - # 复杂度深度可从 state 或常量 - max_depth = getattr(state, "mask_detail_level", 3) - - mineru_pages: List[Dict[str, Any]] = [] - - for page_idx, img_path in enumerate(image_paths): - try: - out_dir = mineru_dir / f"page_{page_idx+1:03d}" - out_dir.mkdir(parents=True, exist_ok=True) - - log.critical(f"【mineru node】: {out_dir}") - - mineru_items = await recursive_mineru_layout( - image_path=str(img_path), - port=port, - max_depth=3, - output_dir=str(out_dir), - ) - - # 记录 MinerU 输出目录,方便后续找 sub_images - # recursive_mineru_layout 会在 out_dir 下直接输出或创建子目录 - # 这里我们记录 out_dir,后续可以在里面找 sub_images - - mineru_pages.append({ - "page_idx": page_idx, - "blocks": mineru_items, - "path": img_path, - "mineru_output_dir": str(out_dir) - }) - log.info(f"[pdf2ppt_with_sam][MinerU] page#{page_idx+1} got {len(mineru_items)} blocks") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][MinerU] page#{page_idx+1} failed: {e}") - 
mineru_pages.append({ - "page_idx": page_idx, - "blocks": [], - "path": img_path, - }) - - state.mineru_pages = mineru_pages - - log.critical(f"[state.mineru_pages]: {state.mineru_pages}") - - return state - - async def slides_sam_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页图片运行 SAM 用于图标 / 图块分割。 - """ - image_paths: List[str] = getattr(state, "slide_images", []) or [] - if not image_paths: - log.error("[pdf2ppt_with_sam] no slide_images for SAM") - return state - - base_dir = _ensure_result_path(state) - sam_pages = _run_sam_on_pages(image_paths, base_dir) - state.sam_pages = sam_pages - return state - - async def slides_layout_bg_remove_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 对每一页 SAM layout PNG 做背景抠图: - - 输入: state.sam_pages[*].layout_items[].png_path - - 输出: 为每个 layout_item 写入 fg_png_path(抠完背景的 PNG) - """ - sam_pages: List[Dict[str, Any]] = getattr(state, "sam_pages", []) or [] - if not sam_pages: - log.error("[pdf2ppt_with_sam] no sam_pages for bg remove") - return state - - base_dir = Path(_ensure_result_path(state)) - icons_dir = base_dir / "sam_icons" - icons_dir.mkdir(parents=True, exist_ok=True) - - model_path = getattr(getattr(state, "request", None), "bg_rm_model", None) - - processed = 0 - - for p in sam_pages: - page_idx = p.get("page_idx", 0) - for it in p.get("layout_items", []): - png_path = it.get("png_path") - if not png_path or not os.path.exists(png_path): - continue - - # 背景抠图 - 添加页码前缀避免文件名冲突 - try: - # 从原始路径提取文件名 - original_stem = Path(png_path).stem - # 创建带页码的输出文件名 - output_filename = f"page_{page_idx+1:03d}_{original_stem}_bg_removed.png" - output_path = icons_dir / output_filename - - req = { - "image_path": png_path, - "output_dir": str(icons_dir), - } - if model_path: - req["model_path"] = model_path - - fg_path = local_tool_for_bg_remove(req) - - # 重命名文件以包含页码 - if fg_path and os.path.exists(fg_path): - # 将生成的文件重命名为带页码的文件名 - fg_path_obj = Path(fg_path) - if fg_path_obj.name != output_filename: - 
new_fg_path = fg_path_obj.parent / output_filename - fg_path_obj.rename(new_fg_path) - fg_path = str(new_fg_path) - - it["fg_png_path"] = fg_path - else: - it["fg_png_path"] = png_path - - processed += 1 - except Exception as e: - log.error(f"[pdf2ppt_with_sam][bg_rm] failed for {png_path}: {e}") - it["fg_png_path"] = png_path - - # 抠图完成后可尝试释放模型(忽略失败) - try: - if model_path: - free_bg_rm_model(model_path=model_path) - except Exception as e: - log.error(f"[pdf2ppt_with_sam] free_bg_rm_model failed: {e}") - - log.info(f"[pdf2ppt_with_sam] bg remove processed: {processed} items") - return state - - async def slides_ppt_generation_node(state: Paper2FigureState) -> Paper2FigureState: - """ - 结合 MinerU + OCR + SAM 结果生成可编辑 PPT: - - 改进点: - 1. MinerU 图片渲染修复:优先复用 MinerU 输出目录下的 sub_images,无法匹配时再手动裁剪。 - 2. 字体归一化: - - 统计全页正文(Body)文本的平均字号,取众数作为标准正文字号。 - - 标题(Title)字号设为标准正文的 1.5 倍(或取 Title 众数)。 - - 强制所有 Body 文本使用 Standard Body Font,所有 Title 文本使用 Standard Title Font。 - 3. 背景生成开关: - - 使用 state.use_ai_edit 控制是否调用 AI 生成纯净背景; - - 关闭时直接使用纯白背景。 - 4. 
并行 API 调用: - - 将 Inpainting API 调用改为并行执行,加快多页处理速度。 - """ - - ocr_pages: List[Dict[str, Any]] = getattr(state, "ocr_pages", []) or [] - sam_pages: List[Dict[str, Any]] = getattr(state, "sam_pages", []) or [] - mineru_pages: List[Dict[str, Any]] = getattr(state, "mineru_pages", []) or [] - - if not ocr_pages: - log.error("[pdf2ppt_with_sam] no ocr_pages, abort PPT generation") - return state - - # 建立索引 - sam_dict = {p.get("page_idx", 0): p.get("layout_items", []) for p in sam_pages} - - # mineru_dict 存放 {"blocks": [], "mineru_output_dir": ...} - # 修复:为了防止 page_idx 类型不一致 (int vs str),构建更鲁棒的索引 - mineru_dict = {} - for p in mineru_pages: - pid = p.get("page_idx", 0) - mineru_dict[pid] = p # 原始类型 - mineru_dict[str(pid)] = p # 字符串类型兼容 - - # 以 PPT 工具里的默认比例创建 Presentation - prs = Presentation() - prs.slide_width = Inches(ppt_tool.SLIDE_W_IN) - prs.slide_height = Inches(ppt_tool.SLIDE_H_IN) - slide_w_emu = prs.slide_width - slide_h_emu = prs.slide_height - - # 初始化 base_dir,确保后续逻辑都能访问 - base_dir = Path(_ensure_result_path(state)) - - # ========================================================== - # 辅助函数:API 重试逻辑 - # ========================================================== - async def _call_image_api_with_retry(coro_factory, retries: int = 3, delay: float = 1.0) -> bool: - """ - 对图像生成/编辑进行最多 retries 次重试。 - """ - last_err: Optional[Exception] = None - for attempt in range(1, retries + 1): - try: - await coro_factory() - return True - except Exception as e: - last_err = e - log.error(f"[pdf2ppt_with_sam] image api failed attempt {attempt}/{retries}: {e}") - if attempt < retries: - try: - await asyncio.sleep(delay) - except Exception: - pass - log.error(f"[pdf2ppt_with_sam] image api failed after {retries} attempts: {last_err}") - return False - - # ========================================================== - # 辅助函数:字体和几何计算 - # ========================================================== - def _get_dominant_font_size(lines, img_h): - """计算正文文本的“众数”字号 (pt)""" - sizes = [] - for 
bbox, _, _ in lines: - pt = ppt_tool.estimate_font_pt(bbox, img_h_px=img_h, body_h_px=None).pt - if pt: - sizes.append(round(pt)) - if not sizes: return 12.0 - counts = Counter(sizes) - dominant = counts.most_common(1)[0][0] - return float(dominant) - - def _bbox_area(bbox): - return max(0, bbox[2] - bbox[0]) * max(0, bbox[3] - bbox[1]) - - def _get_intersection_area(bbox1, bbox2): - x1 = max(bbox1[0], bbox2[0]) - y1 = max(bbox1[1], bbox2[1]) - x2 = min(bbox1[2], bbox2[2]) - y2 = min(bbox1[3], bbox2[3]) - return max(0, x2 - x1) * max(0, y2 - y1) - - def _is_inside(inner, outer, threshold=0.9): - inter = _get_intersection_area(inner, outer) - inner_a = _bbox_area(inner) - if inner_a <= 0: return False - return (inter / inner_a) >= threshold - - def _is_overlap(bbox1, bbox2, threshold=0.1): - inter = _get_intersection_area(bbox1, bbox2) - min_area = min(_bbox_area(bbox1), _bbox_area(bbox2)) - if min_area <= 0: return False - return (inter / min_area) >= threshold - - # ========================================================== - # Phase 1: 准备渲染数据 & 创建 AI 任务 - # ========================================================== - - pages_render_data: List[Dict[str, Any]] = [] - ai_coroutines = [] # List of awaitables - - # 循环处理每一页的布局分析 - for pinfo in ocr_pages: - page_idx = pinfo.get("page_idx", 0) - - # 兼容性查找 - mineru_page_data = mineru_dict.get(page_idx) - if not mineru_page_data: - mineru_page_data = mineru_dict.get(str(page_idx), {}) - if mineru_page_data: - log.warning(f"[pdf2ppt_with_sam] page_idx mismatch fixed by str conversion: {page_idx}") - - img_path = pinfo.get("path") - lines = pinfo.get("lines", []) # List of (bbox, text, conf) - - if not img_path or not os.path.exists(img_path): - log.warning(f"[pdf2ppt_with_sam] missing img for page#{page_idx+1}: {img_path}") - continue - - # 读取原始图像信息 - try: - pil_img = Image.open(img_path) - w0, h0 = pil_img.size - except Exception as e: - log.error(f"Failed to open image {img_path}: {e}") - continue - - # 
----------------------------------------------------------- - # Step 1: 分析 MinerU 结果,划定 "Image Zone" 并找回 sub_images - # ----------------------------------------------------------- - mineru_blocks = mineru_page_data.get("blocks", []) - mineru_out_dir = mineru_page_data.get("mineru_output_dir") - - image_zones = [] # List of {"bbox": [x1,y1,x2,y2], "type": str, "img_path": str} - - # 尝试定位 sub_images 目录 - sub_images_dir = None - sub_images_dirs: List[Path] = [] - if mineru_out_dir: - try: - page_root = Path(mineru_out_dir) - direct = page_root / "sub_images" - if direct.exists() and direct.is_dir(): - sub_images_dirs.append(direct) - for d in page_root.rglob("sub_images"): - if d.is_dir(): - sub_images_dirs.append(d) - seen = set() - unique_dirs: List[Path] = [] - for d in sub_images_dirs: - rp = str(d.resolve()) - if rp not in seen: - seen.add(rp) - unique_dirs.append(d) - for d in unique_dirs: - pngs = list(d.glob("*.png")) - if pngs: - sub_images_dir = d - break - if sub_images_dir: - sub_files = sorted([p.name for p in sub_images_dir.glob("*.png")]) - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] MinerU sub_images dir: {sub_images_dir}, found {len(sub_files)} pngs") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] search sub_images failed: {e}") - - for idx, blk in enumerate(mineru_blocks): - btype = (blk.get("type") or "").lower() - bbox = blk.get("bbox") # norm - if not bbox or len(bbox) != 4: - continue - - x1 = int(round(bbox[0] * w0)) - y1 = int(round(bbox[1] * h0)) - x2 = int(round(bbox[2] * w0)) - y2 = int(round(bbox[3] * h0)) - - if x2 <= x1 or y2 <= y1: continue - px_bbox = [x1, y1, x2, y2] - - is_image_zone = btype in ['image', 'figure', 'table', 'formula'] - img_path_found = None - - if is_image_zone: - if blk.get("img_path") and os.path.exists(blk["img_path"]): - img_path_found = blk["img_path"] - - if not img_path_found and sub_images_dir: - try: - depth = blk.get("depth", 0) - try: - depth = int(depth) - except 
Exception: - depth = 0 - prefix = f"depth{depth}_blk{idx}_" - for f in sorted(sub_images_dir.glob("*.png")): - if f.name.startswith(prefix): - img_path_found = str(f.resolve()) - break - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] match sub_images failed: {e}") - - if not img_path_found: - fallback_dir = base_dir / "mineru_fallback_crops" / f"page_{page_idx+1:03d}" - fallback_dir.mkdir(parents=True, exist_ok=True) - save_path = fallback_dir / f"mineru_{idx}_{btype}.png" - try: - if not save_path.exists(): - crop = pil_img.crop((x1, y1, x2, y2)) - crop.save(save_path) - img_path_found = str(save_path) - except Exception as e: - log.error(f"Failed to crop mineru block {idx}: {e}") - - if img_path_found: - image_zones.append({ - "bbox": px_bbox, - "type": btype, - "img_path": img_path_found - }) - - # ----------------------------------------------------------- - # Step 2: 过滤 OCR 文字 - # ----------------------------------------------------------- - final_ocr_lines = [] # (bbox, text, conf, type) - body_lines_for_stats = [] - - for line in lines: - l_bbox, l_text, l_conf = line - is_in_image = False - for zone in image_zones: - if _is_inside(l_bbox, zone["bbox"]): - is_in_image = True - break - - if not is_in_image: - l_type = "body" - for blk in mineru_blocks: - btype = (blk.get("type") or "").lower() - b_bbox = blk.get("bbox") - if not b_bbox: continue - bx1 = int(round(b_bbox[0] * w0)) - by1 = int(round(b_bbox[1] * h0)) - bx2 = int(round(b_bbox[2] * w0)) - by2 = int(round(b_bbox[3] * h0)) - - if btype in ['title', 'header'] and _is_inside(l_bbox, [bx1, by1, bx2, by2]): - l_type = "title" - break - - final_ocr_lines.append((l_bbox, l_text, l_conf, l_type)) - if l_type == "body": - body_lines_for_stats.append((l_bbox, l_text, l_conf)) - - std_body_pt = _get_dominant_font_size(body_lines_for_stats, h0) - std_title_pt = std_body_pt * 1.5 - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] Standard Body Font: {std_body_pt}pt, Title: 
{std_title_pt}pt") - - # ----------------------------------------------------------- - # Step 3: 过滤 SAM 图块 - # ----------------------------------------------------------- - raw_sam_items = sam_dict.get(page_idx, []) - final_sam_items = [] - - for item in raw_sam_items: - s_bbox = item.get("bbox_px") - if not s_bbox: continue - is_in_image = False - for zone in image_zones: - if _is_inside(s_bbox, zone["bbox"], threshold=0.6): - is_in_image = True - break - if is_in_image: continue - - is_text_block = False - for line in final_ocr_lines: - l_bbox = line[0] - if _is_overlap(s_bbox, l_bbox, threshold=0.3) or _is_inside(l_bbox, s_bbox): - is_text_block = True - break - if is_text_block: continue - - w = s_bbox[2] - s_bbox[0] - h = s_bbox[3] - s_bbox[1] - if w < 5 or h < 5: continue - if w*h < 400: continue - - final_sam_items.append(item) - - # ----------------------------------------------------------- - # Step 4: 准备 AI 背景生成任务 - # ----------------------------------------------------------- - clean_bg_path = base_dir / "clean_backgrounds" / f"clean_bg_{page_idx+1:03d}.png" - clean_bg_path.parent.mkdir(parents=True, exist_ok=True) - - use_ai_bg = bool(getattr(state, "use_ai_edit", False)) - log.critical(f"[pdf2ppt 是否使用AI: ][page#{page_idx+1}] use_ai_bg={use_ai_bg}") - - ai_task = None - if use_ai_bg and os.path.exists(img_path): - try: - # A. 生成 Mask (黑底白框) - ori_cv = cv2.imread(img_path) - if ori_cv is not None: - h_cv, w_cv = ori_cv.shape[:2] - mask_cv = np.zeros((h_cv, w_cv), dtype=np.uint8) # 黑底 - - # 绘制 OCR 区域 (白框) - for line in final_ocr_lines: - bbox = line[0] - pad = 5 - mx1 = int(max(0, bbox[0] - pad)) - my1 = int(max(0, bbox[1] - pad)) - mx2 = int(min(w_cv, bbox[2] + pad)) - my2 = int(min(h_cv, bbox[3] + pad)) - cv2.rectangle(mask_cv, (mx1, my1), (mx2, my2), (255), -1) - - mask_path = base_dir / "masks" / f"mask_{page_idx+1:03d}.png" - mask_path.parent.mkdir(parents=True, exist_ok=True) - cv2.imwrite(str(mask_path), mask_cv) - - # B. 
准备 AI 调用闭包 - req_cfg = getattr(state, "request", None) or {} - if not isinstance(req_cfg, dict): - req_cfg = req_cfg.__dict__ if hasattr(req_cfg, "__dict__") else {} - - api_key = req_cfg.get("api_key") or os.getenv("DF_API_KEY") - api_url = req_cfg.get("chat_api_url") or "https://api.apiyi.com" - model_name = req_cfg.get("gen_fig_model") or "gemini-3-pro-image-preview" - - if api_key: - log.info(f"[pdf2ppt_with_sam][page#{page_idx+1}] Scheduling Gemini Inpainting...") - prompt = ( - "Use the second image as a mask to remove text from the first image. " - "Fill the removed text areas with background texture to make it clean. " - "Keep non-text areas (figures, tables) unchanged." - ) - - async def _run_ai_job(_p_idx=page_idx, _img_p=img_path, _mask_p=str(mask_path), _out_p=str(clean_bg_path)): - await _call_image_api_with_retry( - lambda: gemini_multi_image_edit_async( - prompt=prompt, - image_paths=[_img_p, _mask_p], - save_path=_out_p, - api_url=api_url, - api_key=api_key, - model=model_name, - resolution="1K", - timeout=300 - ) - ) - - ai_task = _run_ai_job() - ai_coroutines.append(ai_task) - else: - log.warning("Skipping AI edit: No API Key provided") - except Exception as e: - log.error(f"[pdf2ppt_with_sam][page#{page_idx+1}] Prepare AI task failed: {e}") - - # 保存所有需要在渲染阶段使用的数据 - pages_render_data.append({ - "page_idx": page_idx, - "scale_x": slide_w_emu / w0, - "scale_y": slide_h_emu / h0, - "clean_bg_path": str(clean_bg_path), - "image_zones": image_zones, - "final_sam_items": final_sam_items, - "final_ocr_lines": final_ocr_lines, - "std_title_pt": std_title_pt, - "std_body_pt": std_body_pt, - "ai_task": ai_task # 用于追踪哪个页面发起了 AI 请求 - }) - - # ========================================================== - # Phase 2: 并发执行 AI 任务 - # ========================================================== - if ai_coroutines: - log.info(f"[pdf2ppt_with_sam] Executing {len(ai_coroutines)} AI background tasks in parallel...") - start_t = __import__("time").time() - # 忽略异常,确保后续 PPT 
渲染能继续(失败的会降级为白底) - await asyncio.gather(*ai_coroutines, return_exceptions=True) - cost = __import__("time").time() - start_t - log.info(f"[pdf2ppt_with_sam] AI tasks finished. cost={cost:.2f}s") - - # ========================================================== - # Phase 3: 生成 PPT 页面 (组装) - # ========================================================== - for p_data in pages_render_data: - # 取出数据 - scale_x = p_data["scale_x"] - scale_y = p_data["scale_y"] - clean_bg_path = p_data["clean_bg_path"] - image_zones = p_data["image_zones"] - final_sam_items = p_data["final_sam_items"] - final_ocr_lines = p_data["final_ocr_lines"] - std_title_pt = p_data["std_title_pt"] - std_body_pt = p_data["std_body_pt"] - - slide = prs.slides.add_slide(prs.slide_layouts[6]) - - # 3.1 设置背景 - bg_image_path_for_ppt = None - if os.path.exists(clean_bg_path): - bg_image_path_for_ppt = clean_bg_path - - if bg_image_path_for_ppt: - try: - slide.shapes.add_picture(bg_image_path_for_ppt, 0, 0, prs.slide_width, prs.slide_height) - except Exception as e: - log.error(f"Failed to set slide background image: {e}") - # 降级 - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - else: - bg = slide.background - fill = bg.fill - fill.solid() - fill.fore_color.rgb = RGBColor(255, 255, 255) - - # 3.2 渲染 MinerU Image Zones - for zone in image_zones: - ipath = zone["img_path"] - if not os.path.exists(ipath): - log.warning(f"MinerU image path not found: {ipath}") - continue - - bbox = zone["bbox"] - left = ppt_tool.px_to_emu(bbox[0], scale_x) - top = ppt_tool.px_to_emu(bbox[1], scale_y) - width = ppt_tool.px_to_emu(bbox[2] - bbox[0], scale_x) - height = ppt_tool.px_to_emu(bbox[3] - bbox[1], scale_y) - - try: - slide.shapes.add_picture(ipath, left, top, width, height) - except Exception as e: - log.error(f"Failed to add mineru image: {e}") - - # 3.3 渲染 SAM Icons - for item in final_sam_items: - ipath = item.get("fg_png_path") or item.get("png_path") - if not ipath 
or not os.path.exists(ipath): continue - - bbox = item.get("bbox_px") - left = ppt_tool.px_to_emu(bbox[0], scale_x) - top = ppt_tool.px_to_emu(bbox[1], scale_y) - width = ppt_tool.px_to_emu(bbox[2] - bbox[0], scale_x) - height = ppt_tool.px_to_emu(bbox[3] - bbox[1], scale_y) - - try: - slide.shapes.add_picture(ipath, left, top, width, height) - except Exception as e: - log.error(f"Failed to add SAM icon: {e}") - - # 3.4 渲染 OCR Text - for line in final_ocr_lines: - bbox, text, conf, l_type = line - x1, y1, x2, y2 = bbox - if (x2 - x1) < 5 or (y2 - y1) < 5: continue - - left = ppt_tool.px_to_emu(x1, scale_x) - top = ppt_tool.px_to_emu(y1, scale_y) - width = max(1, ppt_tool.px_to_emu(x2 - x1, scale_x)) - height = max(1, ppt_tool.px_to_emu(y2 - y1, scale_y)) - - tb = slide.shapes.add_textbox(left, top, width, height) - tf = tb.text_frame - tf.clear() - tf.word_wrap = True - tb.fill.background() - tb.line.fill.background() - - p = tf.paragraphs[0] - p.text = text - - if l_type == "title": - p.font.size = Pt(std_title_pt) - p.font.bold = True - else: - p.font.size = Pt(std_body_pt) - - p.font.color.rgb = RGBColor(0, 0, 0) - - # Save - # base_dir 已在函数开头定义 - ppt_path = base_dir / "pdf2ppt_with_sam_output.pptx" - prs.save(str(ppt_path)) - state.ppt_path = str(ppt_path) - log.info(f"[pdf2ppt_with_sam] PPT generated: {ppt_path}") - - return state - - nodes = { - "_start_": _init_result_path, - "pdf_to_images": pdf_to_images_node, - "slides_ocr": slides_ocr_node, - "slides_mineru": slides_mineru_node, - "slides_sam": slides_sam_node, - "slides_layout_bg_remove": slides_layout_bg_remove_node, - "slides_ppt_generation": slides_ppt_generation_node, - "_end_": lambda state: state, - } - - edges = [ - ("pdf_to_images", "slides_ocr"), - ("slides_ocr", "slides_mineru"), - ("slides_mineru", "slides_sam"), - ("slides_sam", "slides_layout_bg_remove"), - ("slides_layout_bg_remove", "slides_ppt_generation"), - ("slides_ppt_generation", "_end_"), - ] - - 
builder.add_nodes(nodes).add_edges(edges) - builder.add_edge("_start_", "pdf_to_images") - return builder diff --git a/dataflow_agent/workflow/wf_test_graph.py b/dataflow_agent/workflow/wf_test_graph.py deleted file mode 100644 index f8acac4..0000000 --- a/dataflow_agent/workflow/wf_test_graph.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -test_graph workflow -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -生成时间: 2025-12-01 20:16:43 - -1. 在 **TOOLS** 区域定义需要暴露给 Prompt 的前置工具 -2. 在 **NODES** 区域实现异步节点函数 (await-able) -3. 在 **EDGES** 区域声明有向边 -4. 最后返回 builder.compile() 或 GenericGraphBuilder -""" - -from __future__ import annotations -import json -from dataclasses import Field -from pydantic import BaseModel -from dataflow_agent.states.test_graph_state import TestGraphState -# from dataflow_agent.state import TestGraphState -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.workflow.registry import register -from dataflow_agent.agentroles import ( - create_agent, - create_simple_agent, - create_react_agent, - create_graph_agent, - create_vlm_agent, - SimpleConfig, - ReactConfig, - GraphConfig, - VLMConfig, - ExecutionMode, -) - -from dataflow_agent.toolkits.tool_manager import get_tool_manager -from langchain.tools import tool -from langgraph.graph import StateGraph -from langgraph.prebuilt import ToolNode, tools_condition - -from dataflow_agent.graphbuilder.graph_builder import GenericGraphBuilder -from dataflow_agent.logger import get_logger - -log = get_logger(__name__) - -@register("test_graph") -def create_test_graph_graph() -> GenericGraphBuilder: # noqa: N802 - """ - Workflow factory: dfa run --wf test_graph - """ - builder = GenericGraphBuilder(state_model=TestGraphState, - entry_point="test_graph") # 自行修改入口 - - # ---------------------------------------------------------------------- - # TOOLS (pre_tool definitions) - # ---------------------------------------------------------------------- - # 例: - 
@builder.pre_tool("purpose", "test_graph") - def _purpose(state: TestGraphState): - return "请问日期 11-29 的天气是什么???如果天气晴朗,请帮我购买这天的火车票!!" - - - @builder.post_tool("test_graph") - @tool - def _get_tomorrow_weather(data_str: str): - """ - 获取明天天气 - - Args: - data_str: 日期字符串,格式为 "MM-DD" - - """ - return "明天天气晴朗!!!!!!!!!!!!!" - - @builder.post_tool("test_graph") - @tool - def _get_ticket(data_str: str): - """ - 购买日期的火车票 - - Args: - data_str: 日期字符串,格式为 "MM-DD" - - """ - return "购买1张11-29的火车票!!!!!!!!!!!!!" - - - # ---------------------------------------------------------------------- - - # ============================================================== - # NODES - # ============================================================== - async def test_graph_node(state: TestGraphState) -> TestGraphState: - """ - 示例节点 1: 使用新的策略模式创建和执行 Agent - - 新版 Agent 创建方式推荐使用 `create_agent` 配合配置对象 (Config) - 或使用便捷函数 `create_simple_agent`, `create_react_agent` 等。 - - 执行模式说明: - - SimpleConfig: 简单模式,单次 LLM 调用 - - ReactConfig: ReAct 模式,带验证和重试的循环 - - GraphConfig: 图模式,用于执行带工具的子图 (LangGraph) - - VLMConfig: 视觉语言模型模式 - """ - - agent = create_graph_agent( - name="test_graph", - model_name="deepseek-v3.2", - temperature=0.1, - max_tokens=65536, - parser_type="json", - ) - - state = await agent.execute(state=state) - - log.critical(f"state.messages: {state.messages}") - - # 可选:处理执行结果 - agent_result = state.agent_results.get(agent.role_name, {}) - log.info(f"Agent {agent.role_name} 执行结果: {agent_result}") - - return state - - async def step2(state: TestGraphState) -> TestGraphState: - """ - 示例节点 2: 处理agent执行结果 - - Args: - state: 主状态对象 - """ - # TODO: 替换为真正的业务逻辑 - state.agent_results["step2"] = {"msg": "hello step2"} - - # 示例:从 step1 的结果中提取数据 - # if "code_reviewer" in state.agent_results: - # review_result = state.agent_results["code_reviewer"] - # # 处理审查结果... 
- - return state - - # ============================================================== - # 注册 nodes / edges - # ============================================================== - nodes = { - "test_graph": test_graph_node, - "step2": step2, - '_end_': lambda state: state, # 终止节点 - } - - # ------------------------------------------------------------------ - # EDGES (从节点 A 指向节点 B) - # ------------------------------------------------------------------ - edges = [ - ("test_graph", "step2"), - ("step2", "_end_"), # 指向终止节点 - ] - - builder.add_nodes(nodes).add_edges(edges) - return builder \ No newline at end of file diff --git a/docs/.gitkeep b/docs/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/docs/DATABASE_STRUCTURE.md b/docs/DATABASE_STRUCTURE.md deleted file mode 100644 index 1ca1604..0000000 --- a/docs/DATABASE_STRUCTURE.md +++ /dev/null @@ -1,142 +0,0 @@ -# 当前数据库里会存什么 / 能存什么(按现有表结构) - -以**当前代码 + 现有 SQL 表结构**为准,说明:创建新笔记本时库里会多什么、以及库里都有哪些表、分别存什么。 - ---- - -## 一、创建新笔记本时,数据库里会多什么? - -- **配置了 Supabase 且后端能连上时** - 会往 **`knowledge_bases`** 表里插入**一行**,例如: - - | 字段 | 含义 | 示例 | - |------|------|------| - | `id` | 笔记本唯一 ID(UUID) | 自动生成 | - | `user_id` | 所属用户(Supabase auth.users.id) | 前端传的 user_id | - | `name` | 笔记本名称 | 用户输入的「笔记本名称」 | - | `description` | 描述 | 可为空,当前前端一般传 `""` | - | `created_at` | 创建时间 | 自动 | - | `updated_at` | 更新时间 | 自动 | - - 也就是说:**创建新笔记本 = 只在 `knowledge_bases` 里新增一条记录**,不会自动在别的表里写东西。 - -- **未配置 Supabase 或连不上时** - 不会写数据库,会走本地回退:在 `outputs/kb_data/_notebooks/{user_id}.json` 里追加一条笔记本(id 为 `local_xxx` 这种)。 - ---- - -## 二、当前数据库里「有什么表」「能存什么」(按现有数据结构) - -下面都是**按现有 SQL 建表脚本**来的,执行了哪些脚本,就有哪些表。 - -### 1. `knowledge_bases`(笔记本/目录) - -- **能存什么**:每个「笔记本」一条记录。 -- **主要字段**: - - `id` (UUID),`user_id` (谁创建的),`name`(名称),`description`(描述),`created_at` / `updated_at`。 -- **谁在写**:前端点「新建笔记本」→ 后端 `POST /api/v1/kb/notebooks` → 有 Supabase 时插入这里。 - ---- - -### 2. 
`knowledge_base_files`(知识库文件元数据) - -- **能存什么**:每个上传到知识库的文件的**元数据**(文件名、类型、大小、存储路径、属于哪个笔记本等)。**不存文件内容**,文件本体在本地或 Storage。 -- **主要字段**: - - `id`, `user_id`, `user_email`, `file_name`, `file_type`, `file_size`, `storage_path`(路径或 URL), `is_embedded`, `kb_file_id`(向量库里的 id), **`kb_id`**(属于哪个笔记本,对应 `knowledge_bases.id`), `description`, `created_at`. -- **谁在写**:用户上传文件时,前端在 Supabase 里 insert 一行(并带上当前 `notebook.id` 作 `kb_id`);embedding 写回时可能更新 `kb_file_id`、`is_embedded` 等。 - ---- - -### 3. `kb_conversations`(对话会话) - -- **能存什么**:每个「对话」一条记录,可关联到某个笔记本(也可不关联,全局对话)。 -- **主要字段**: - - `id`, `user_id`, `user_email`, **`notebook_id`**(关联到 `knowledge_bases.id`,可为空), `title`, `created_at`, `updated_at`. -- **谁在写**:用户打开某个笔记本并发第一条消息时,后端「获取或创建」该用户+该笔记本的一条会话,插入或更新这里。 - ---- - -### 4. `kb_chat_messages`(对话里的每条消息) - -- **能存什么**:某次对话下的**每条**用户/助手消息。 -- **主要字段**: - - `id`, **`conversation_id`**(属于哪条 `kb_conversations`), `role`('user'/'assistant'/'system'), `content`(文本), `created_at`. -- **谁在写**:用户在该笔记本里发消息、后端返回回复后,后端往对应 `conversation_id` 下 append 两条:一条 user,一条 assistant。 - ---- - -### 5. `kb_output_records`(生成记录:PPT/思维导图/播客) - -- **能存什么**:每次用知识库生成 PPT、思维导图、播客的**一条记录**(类型、路径、下载地址等)。 -- **主要字段**: - - `id`, `user_id`, `user_email`, **`notebook_id`**(可选,当前代码多传 null), `output_type`('ppt'/'mindmap'/'podcast'), `file_name`, `file_path`, `result_path`, `download_url`, `extra` (JSONB), `created_at`. -- **谁在写**:后端在 `generate_ppt_from_kb` / `generate_podcast_from_kb` / `generate_mindmap_from_kb` 成功返回前调用 `_save_output_record(...)` 插入这里。 - ---- - -### 6. 
其他表(01_init_schema 等里的,和「笔记本」无直接绑定) - -- **`usage_records`**:按用户、按 workflow 类型的调用记录(可做用量/配额)。 -- **`user_files`**:生成的文件的元数据(通用,不限于知识库)。 -- **`profiles`**:用户资料(如邀请码等)。 -- **`referrals`**:邀请关系。 -- **`points_ledger`**:积分流水;**`points_balance`** 为视图,算当前余额。 - -这些表**不会**在「创建新笔记本」时被写入;和当前「笔记本 + 知识库」逻辑直接相关的是上面 1~5。 - ---- - -## 三、小结(按当前逻辑 + 现有数据结构) - -- **创建新笔记本**: - 数据库里**只会**在 **`knowledge_bases`** 里多一条记录(id、user_id、name、description、时间戳)。 - 其他表不会因为「点一下创建笔记本」而自动有数据。 - -- **当前数据库里会/能存的东西**(和笔记本/知识库相关的): - 1. **笔记本本身**:`knowledge_bases` - 2. **每个笔记本里的文件元数据**:`knowledge_base_files`(通过 `kb_id` 关联笔记本) - 3. **每个笔记本的对话会话**:`kb_conversations`(通过 `notebook_id` 关联笔记本) - 4. **每条对话里的消息**:`kb_chat_messages`(通过 `conversation_id` 关联会话) - 5. **生成结果记录**:`kb_output_records`(目前多数 `notebook_id` 为 null,但表结构支持按笔记本存) - -以上都是**按现有数据结构**说明的「会存什么、能存什么」;若你后续改了表或写入逻辑,以实际代码和迁移脚本为准。 - ---- - -## 四、来源(文件列表)怎么读?和用户怎么连? - -「来源」= 左侧展示的知识库文件列表,**只在前端读**,和用户的绑定靠 **user_id(和可选 notebook id)**。 - -### 1. 用户是谁(和谁连) - -- **配置了 Supabase 时** - - 登录态来自 `supabase.auth.getSession()` / `onAuthStateChange`。 - - 前端把 `session` 放进 `authStore`,用到的用户标识是 **`user.id`**(Supabase `auth.users.id`)和 `user.email`。 - - 所以「当前用户」= 当前 Supabase 登录用户的 `user.id`。 - -- **未配置 Supabase 时** - - 使用 mock 用户:`user.id = 'dev-user-001'`,`user.email = 'dev@notebook.local'`。 - - 所有「来源」和「笔记本」都按这个 mock 用户隔离(本地 JSON + localStorage)。 - -### 2. 来源(文件列表)怎么读——和笔记本绑定,每个笔记本独立来源 - -- **配置了 Supabase 且当前笔记本来自数据库(UUID)** - - 在 `NotebookView` 里调 `fetchFiles()`: - - `supabase.from('knowledge_base_files').select('*').eq('user_id', user?.id).eq('kb_id', notebook.id)` - - 即:**只读该笔记本下的来源**(按 `user_id` + `kb_id`)。 - - 上传时插入 `knowledge_base_files` 并带 `kb_id = notebook.id`,所以来源和笔记本一一对应。 - -- **本地笔记本(id 以 `local_` 开头)或未配置 Supabase** - - 从 **localStorage** 读:key = `kb_files_${user.id}_${notebook.id}`(有笔记本时),这样**每个笔记本一个 key,互不共用**。 - - 上传时只往当前笔记本对应的 key 里 append,不写入 Supabase(本地笔记本没有对应 `knowledge_bases` 行)。 - -### 3. 
和用户、笔记本的连接总结 - -| 环节 | 和用户、笔记本怎么连 | -|--------------|----------------------| -| 用户身份 | Supabase:`auth.session.user.id` → `user.id`;未配置:mock `user.id` | -| 来源列表读取 | **每个笔记本独立**:Supabase 笔记本按 `user_id` + `kb_id` 查;本地笔记本按 localStorage `kb_files_${user.id}_${notebook.id}` | -| 上传文件落库 | 数据库笔记本:insert 带 `user_id` + `kb_id`;本地笔记本:只写 localStorage 对应 key | -| 笔记本列表 | 后端 GET /kb/notebooks 传 `user_id`;Supabase 表 `knowledge_bases` 按 `user_id` 查 | - -所以:**来源和笔记本一一对应**:每个笔记本只显示、只写入该笔记本下的来源;和用户的连接是 **user_id**,和笔记本的连接是 **kb_id**(数据库)或 **localStorage key 里的 notebook.id**(本地)。 diff --git a/docs/SUPABASE_CONFIG.md b/docs/SUPABASE_CONFIG.md deleted file mode 100644 index c5e634f..0000000 --- a/docs/SUPABASE_CONFIG.md +++ /dev/null @@ -1,90 +0,0 @@ -# Supabase / 数据库配置说明 - -项目里**数据库(Supabase)**的配置分布在**后端**和**前端**两处,通过环境变量传入。 - ---- - -## 1. 后端(FastAPI) - -### 配置位置 - -- **读取代码**:`fastapi_app/dependencies/auth.py` - - `get_supabase_client()`:用 **anon key**,用于 JWT 校验、前端直连时的 RLS - - `get_supabase_admin_client()`:用 **service role key**,用于服务端写库(对话、输出、笔记本等) - -- **环境变量**(需在**运行后端的进程**里生效): - - | 变量名 | 用途 | 必填 | - |--------|------|------| - | `SUPABASE_URL` | 项目 URL,如 `https://xxxx.supabase.co` | 用 Supabase 时必填 | - | `SUPABASE_ANON_KEY` | 匿名公钥(Settings → API → anon public) | JWT 校验时必填 | - | `SUPABASE_SERVICE_ROLE_KEY` | 服务端密钥(Settings → API → service_role) | 后端写库时必填 | - -- **写入方式**:在 **`fastapi_app/.env`** 中配置(若用 dotenv),或启动前在 shell 里 `export`。 - -### 后端如何读到 .env - -- `fastapi_app/config/settings.py` 里 Pydantic 的 `env_file = ".env"` 只影响 **settings** 里的配置项,**不会**自动给 `os.getenv()` 用的 Supabase 变量加料。 -- 若希望用 `fastapi_app/.env` 给 Supabase 用,需要: - - 在**项目根**或 **fastapi_app** 目录下启动(且把 `.env` 放在同一目录),并确保有地方执行 `load_dotenv()`(例如在 `main.py` 最开头加 `from dotenv import load_dotenv; load_dotenv()`),或 - - 在启动命令前 `export` 上述三个变量。 - -**示例 `fastapi_app/.env`:** - -```env -SUPABASE_URL=https://你的项目.supabase.co -SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... 
-SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... -``` - ---- - -## 2. 前端(Vite / frontend-v2) - -### 配置位置 - -- **读取代码**:`frontend-v2/src/lib/supabase.ts` - - 用 `import.meta.env.VITE_SUPABASE_URL` 和 `import.meta.env.VITE_SUPABASE_ANON_KEY` 创建 Supabase 客户端。 - - `isSupabaseConfigured()` 为 true 时才会用 Supabase(登录、知识库文件、表 `knowledge_base_files` 等)。 - -- **环境变量**(必须以 `VITE_` 开头,构建/开发时注入): - - `VITE_SUPABASE_URL`:同后端,项目 URL。 - - `VITE_SUPABASE_ANON_KEY`:同后端的 anon key(**不要**在前端放 service_role key)。 - -- **写入方式**:在 **`frontend-v2/.env`** 或 `frontend-v2/.env.local` 中配置。 - -**示例 `frontend-v2/.env`:** - -```env -VITE_SUPABASE_URL=https://你的项目.supabase.co -VITE_SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... -``` - ---- - -## 3. 配置汇总表 - -| 用途 | 后端 env | 前端 env | 说明 | -|------|----------|----------|------| -| 项目 URL | `SUPABASE_URL` | `VITE_SUPABASE_URL` | 同一项目填同一个 URL | -| 匿名 key | `SUPABASE_ANON_KEY` | `VITE_SUPABASE_ANON_KEY` | 前端 + 后端 JWT 校验 | -| 服务端 key | `SUPABASE_SERVICE_ROLE_KEY` | 不配置 | 仅后端,用于写库 | - ---- - -## 4. 未配置时的行为 - -- **后端**:`get_supabase_admin_client()` 返回 `None` 时,笔记本/对话/输出会走**本地回退**(本地 JSON 文件或磁盘扫描),不会报错。 -- **前端**:未配置 `VITE_SUPABASE_*` 时,`isSupabaseConfigured()` 为 false,使用 mock 用户和 localStorage,不连 Supabase。 - ---- - -## 5. 
表结构(Supabase SQL Editor) - -若使用 Supabase,需在项目中执行建表脚本(通常位于 `database/` 或项目根下的 SQL 文件),例如: - -- `01_init_schema.sql`:基础表 + `knowledge_bases`、`knowledge_base_files` 等 -- `05_kb_conversations.sql`:对话与消息表 -- `06_kb_output_records.sql`:生成记录表 - -在 Supabase 控制台 → SQL Editor 中执行对应脚本即可。 diff --git a/docs/changelog.md b/docs/changelog.md deleted file mode 100644 index 7d016cc..0000000 --- a/docs/changelog.md +++ /dev/null @@ -1,4 +0,0 @@ -# 更新日志 - -## 0.1.0 - 2024-05-01 -- 🎉 首次发布 diff --git a/docs/cli.md b/docs/cli.md deleted file mode 100644 index 4622217..0000000 --- a/docs/cli.md +++ /dev/null @@ -1,385 +0,0 @@ -# 🛠️ Paper2Any CLI 脚手架使用说明 - -Paper2Any 内置了一套基于 Jinja2 模板的 CLI 代码生成工具(来自 DataFlow-Agent 框架),可以快速生成 **Agent / Workflow / Gradio 页面 / Prompt 模板 / State / Agent-as-Tool** 等标准化代码文件,极大提升开发效率。 - -> CLI 可执行入口通常为 `dfa`(或等价的 Python entrypoint),下文统一使用: -> -> ```bash -> dfa create ... -> ``` - ---- - -## 功能总览 - -CLI 提供以下代码模板类型: - -| 命令参数 | 功能说明 | 生成文件示例 | 自动集成能力 | -|-------------------------|--------------------|-----------------------------------------------|----------------------------| -| `--agent_name` | 创建 Agent 角色 | `agentroles/{name}_agent.py` | ✅ `@register` 自动注册 | -| `--wf_name` | 创建 Workflow | `workflow/wf_{name}.py` + `tests/test_{name}.py` | ✅ Workflow 自动注册 | -| `--gradio_name` | 创建 Gradio 页面 | `gradio_app/pages/page_{name}.py` | ✅ 页面自动发现 | -| `--prompt_name` | 创建 Prompt 模板库 | `promptstemplates/resources/pt_{name}_repo.py` | 手动在 Agent 中引用 | -| `--state_name` | 创建自定义 State | `states/{name}_state.py` | 手动在 Workflow / Agent 中使用 | -| `--agent_as_tool_name` | 创建 Agent 工具 | `agentroles/{name}_agent.py` | ✅ `@register` + Tool 集成 | - ---- - -## 基本用法 - -```bash -# 查看帮助(如 CLI 支持) -dfa --help -dfa create --help - -# 典型使用:根据不同参数生成对应模板 -dfa create --agent_name my_agent -dfa create --wf_name text_pipeline -dfa create --gradio_name paper2figure -dfa create --prompt_name code_review -dfa create --state_name image_processing -dfa create --agent_as_tool_name text_summarizer -``` - 
---- - -## 1. 创建 Agent 角色 - -### 命令示例 - -```bash -dfa create --agent_name sentiment_analyzer -``` - -### 生成内容 - -- 文件路径(示例): - `dataflow_agent/agentroles/common_agents/sentiment_analyzer_agent.py` - -### 主要特性 - -- ✅ 自动注册到 Agent 注册中心(`@register("sentiment_analyzer")`) -- ✅ 预置 `BaseAgent` 继承结构 -- ✅ 预留 prompt 配置接口 -- ✅ 支持 Simple / ReAct / Graph / VLM 等多种执行策略扩展 -- ✅ 提供异步执行入口函数和工厂方法 - -### 典型代码结构(示意) - -```python -from dataflow_agent.agentroles.base_agent import BaseAgent -from dataflow_agent.agentroles.registry import register -from dataflow_agent.states import MainState - -@register("sentiment_analyzer") -class SentimentAnalyzer(BaseAgent): - """情感分析 Agent 示例""" - - @property - def system_prompt_template_name(self) -> str: - # 返回系统 Prompt 名称(在 promptstemplates 中定义) - return "system_prompt_for_sentiment_analyzer" - - def get_task_prompt_params(self, pre_tool_results) -> dict: - # TODO: 自定义参数映射逻辑 - return {} - - @classmethod - def create(cls, tool_manager=None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) - - async def execute(self, state: MainState) -> MainState: - # TODO: 实现核心业务逻辑 - return state - - -# 便捷调用函数 -async def sentiment_analyzer(state: MainState, **kwargs) -> MainState: - agent = SentimentAnalyzer.create(**kwargs) - return await agent.execute(state) -``` - ---- - -## 2. 
创建 Workflow(工作流) - -### 命令示例 - -```bash -dfa create --wf_name text_processing -``` - -### 生成内容 - -- Workflow 定义:`dataflow_agent/workflow/wf_text_processing.py` -- 测试用例:`tests/test_text_processing.py` - -### 主要特性 - -- ✅ 自动注册到 Workflow 注册中心(如 `@register("text_processing")`) -- ✅ 基于 StateGraph 的节点/边定义框架 -- ✅ 预置 pre_tool / post_tool 装饰器示例 -- ✅ 包含完整单元测试模板 -- ✅ 支持多种 Agent 创建与组合模式 - -### Workflow 代码结构(示意) - -```python -from dataflow_agent.workflow import GenericGraphBuilder, register_workflow -from dataflow_agent.states import MainState - -@register_workflow("text_processing") -def create_text_processing_graph() -> GenericGraphBuilder: - builder = GenericGraphBuilder(state_model=MainState, entry_point="step1") - - # 前置工具(可选) - @builder.pre_tool("purpose", "step1") - def _purpose(state: MainState): - return "工具描述或提示词参数" - - # 节点定义 - async def step1(state: MainState) -> MainState: - # 可以在这里创建/调用 Agent - # agent = create_simple_agent(name="your_agent", ...) - # return await agent.execute(state) - return state - - # 注册节点与边 - builder.add_nodes({ - "step1": step1, - }).add_edges([ - ("step1", "_end_"), - ]) - - return builder -``` - -### 测试示例(简化) - -```python -from dataflow_agent.workflow import run_workflow -from dataflow_agent.states import MainState - -async def test_text_processing(): - state = MainState(messages=["hello"]) - result = await run_workflow("text_processing", state=state) - assert isinstance(result, MainState) -``` - -运行测试: - -```bash -pytest tests/test_text_processing.py -v -s -``` - ---- - -## 3. 
创建 Gradio 页面 - -### 命令示例 - -```bash -dfa create --gradio_name model_hub -``` - -### 生成内容 - -- 文件:`gradio_app/pages/page_model_hub.py` - -### 主要特性 - -- ✅ 自动被 `gradio_app/app.py` 发现和注册到 Tab -- ✅ 统一命名:`create_{page_name}` 函数 -- ✅ 内置 Gradio UI 组件示例 -- ✅ 预留调用 Workflow / Agent 的逻辑框架 - -### 页面结构示例 - -```python -import gradio as gr - -def create_model_hub() -> gr.Blocks: - with gr.Blocks() as page: - gr.Markdown("## Model Hub") - # TODO: 添加输入输出组件、按钮、回调等 - return page - -# 示例:调用 workflow 的占位函数 -async def run_xxx_pipeline(...): - # TODO: 调用 dataflow_agent.workflow.run_workflow(...) - # state = await run_workflow("wf_xxx", state) - # return state - return ... -``` - -> 重启 `python gradio_app/app.py` 后,新页面会自动出现在 Web 界面的 Tab 中。 - ---- - -## 4. 创建 Prompt 模板库 - -### 命令示例 - -```bash -dfa create --prompt_name code_review -``` - -### 生成内容 - -- 文件:`dataflow_agent/promptstemplates/resources/pt_code_review_repo.py` - -### 代码结构示例 - -```python -class CodeReview: - task_prompt_for_example = """ - Your task description here. - Input: {input_data} - """ - - system_prompt_for_example = """ - You are an AI assistant for code review tasks. - """ -``` - -### 使用方式(在 Agent 中调用) - -```python -from dataflow_agent.promptstemplates.resources.pt_code_review_repo import CodeReview - -# 在 Agent 中指定模板名 -@property -def task_prompt_template_name(self) -> str: - return "task_prompt_for_example" -``` - ---- - -## 5. 
创建自定义 State - -### 命令示例 - -```bash -dfa create --state_name image_processing -``` - -### 生成内容 - -- 文件:`dataflow_agent/states/image_processing_state.py` - -### 代码结构示例 - -```python -from dataclasses import dataclass, field -from dataflow_agent.states import MainRequest, MainState - -@dataclass -class ImageProcessingRequest(MainRequest): - """自定义请求参数""" - # TODO: 在这里增加你的字段 - pass - -@dataclass -class ImageProcessingState(MainState): - """自定义状态对象""" - request: ImageProcessingRequest = field(default_factory=ImageProcessingRequest) -``` - -### 使用方式 - -```python -from dataflow_agent.states.image_processing_state import ImageProcessingState - -state = ImageProcessingState(messages=[]) -``` - ---- - -## 6. 创建 Agent-as-Tool(可作为 Tool 被调用的 Agent) - -### 命令示例 - -```bash -dfa create --agent_as_tool_name text_summarizer -``` - -### 生成内容 - -- 文件:`dataflow_agent/agentroles/text_summarizer_agent.py` - -### 主要特性 - -- ✅ 既可作为普通 Agent 使用 -- ✅ 又可作为 Tool 提供给其他 Agent / Workflow 调用 -- ✅ 支持自定义工具描述和参数 Schema -- ✅ 自动完成参数解析与映射 - -### 代码结构示例 - -```python -from pydantic import BaseModel, Field -from dataflow_agent.agentroles.base_agent import BaseAgent -from dataflow_agent.agentroles.registry import register - -@register("text_summarizer") -class TextSummarizer(BaseAgent): - """文本总结 Agent / Tool""" - - def get_tool_description(self) -> str: - return "用于总结文本内容" - - def get_tool_args_schema(self) -> type[BaseModel]: - class SummarizerArgs(BaseModel): - content: str = Field(description="要总结的内容") - max_length: int = Field(default=500, description="摘要最大长度") - return SummarizerArgs - - async def execute(self, state): - # TODO: 实现核心 summarization 逻辑 - return state -``` - -### 作为 Tool 使用(示意) - -```python -# 在其他 Agent 或 Workflow 中 -# 例如在 ReAct / Graph Agent 中启用工具模式 -# text_summarizer 会自动出现在可用工具列表中 -``` - ---- - -## 7. 
模板通用特性 - -- 🕐 **时间戳**:生成文件通常包含创建时间注释,方便追踪。 -- 🔤 **智能命名**:自动处理 snake_case / CamelCase 转换。 -- 📝 **TODO 标记**:关键位置预留 `TODO` 注释,指引你补充业务逻辑。 -- 🎯 **最佳实践**:遵循项目内部约定的编码风格与结构。 -- 🔗 **自动集成**: - - Agent / Workflow 自动注册; - - Gradio 页面自动发现; - - Prompt / State 模板方便在 Agent / Workflow 中复用。 - ---- - -## 8. 命名规范与自动转换 - -CLI 会对输入名称进行规范化处理,保证**文件名、类名、注册名**统一: - -```bash -# 以下三种写法等价 -dfa create --agent_name "My Data Processor" -dfa create --agent_name "my-data-processor" -dfa create --agent_name "my_data_processor" - -# 统一转换为: -# - 文件名: my_data_processor_agent.py -# - 类名: MyDataProcessor -# - 注册名: "my_data_processor" -``` - -> 建议尽量使用语义清晰的英文名称,方便在大型 Workflow 中组织与检索。 - ---- - -以上即为 Paper2Any / DataFlow-Agent CLI 脚手架的整理版说明,建议你在本项目中创建 Agent / Workflow / 页面时优先使用 CLI,以保持代码风格一致并提升开发效率。 diff --git a/docs/contributing.md b/docs/contributing.md deleted file mode 100644 index 4756680..0000000 --- a/docs/contributing.md +++ /dev/null @@ -1,75 +0,0 @@ -## 🤝 贡献指南 - -### 开发流程 - -```bash -# 1. Fork并克隆 -git clone https://github.com//DataFlow-Agent.git -cd DataFlow-Agent - -# 2. 安装开发依赖 -pip install -r requirements-dev.txt -pip install -e . - -# 3. 创建分支 -git checkout -b feature/your-feature - -# 4. 运行测试 -pytest - -# 5. 提交PR -git push origin feature/your-feature -``` - -### 添加新Agent - -```python -from dataflow_agent.agentroles.base_agent import BaseAgent -from dataflow_agent.agentroles.registry import register - -@register("my_agent") # 自动注册 -class MyAgent(BaseAgent): - @classmethod - def create(cls, tool_manager=None, **kwargs): - return cls(tool_manager=tool_manager, **kwargs) -``` - -### 添加新Workflow - -```python -# 文件: dataflow_agent/workflow/wf_my_workflow.py -from dataflow_agent.workflow.registry import register -from dataflow_agent.graphbuilder import GraphBuilder - -@register("my_workflow") # 注册名 = 文件名去掉wf_前缀 -def create_my_workflow_graph(): - builder = GraphBuilder() - # 定义节点和边... 
- return builder -``` - -### 添加Gradio页面 - -```python -# 文件: gradio_app/pages/my_page.py -import gradio as gr - -def create_my_page(): # 函数名 = create_ + 文件名 - with gr.Blocks() as page: - gr.Markdown("## 我的页面") - # 添加组件... - return page -``` - -### 文档贡献 - -```bash -# 本地预览 -pip install mkdocs-material -mkdocs serve # 访问 http://127.0.0.1:8000 - -# 添加新页面 -# 1. 在docs/对应目录创建.md文件 -# 2. 在mkdocs.yml的nav中添加链接 -# 3. 提交PR -``` \ No newline at end of file diff --git a/docs/faq.md b/docs/faq.md deleted file mode 100644 index 49f8650..0000000 --- a/docs/faq.md +++ /dev/null @@ -1,5 +0,0 @@ -# 常见问题 FAQ - -### Q1. 为什么运行内存占用高? - -A: 请确认是否启用了 `--debug` 模式,详见文档。 diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md deleted file mode 100644 index d1953cb..0000000 --- a/docs/guides/configuration.md +++ /dev/null @@ -1,662 +0,0 @@ -# 🔧 开发者配置指南 - -## 📖 简介 - -本文档面向 Paper2Any 项目的开发者,详细讲解如何从零开始配置项目环境、配置模型服务、并成功启动整个系统。 - -通过本指南,你将学会: -- 如何正确配置前端和后端的环境变量 -- 如何理解和使用三层模型配置架构 -- 如何配置和启动模型服务器集群 -- 如何排查常见的配置问题 - -## 📋 配置文件概览 - -Paper2Any 项目包含以下主要配置文件: - -| 配置文件 | 路径 | 用途 | -|---------|------|------| -| 前端环境变量 | `frontend-workflow/.env.example` | 配置前端 API 通信、LLM 提供商、Supabase | -| 后端环境变量 | `fastapi_app/.env.example` | 配置后端模型、数据库、API 服务 | -| 模型服务器启动脚本 | `script/start_model_servers.sh` | 配置 MinerU、SAM、OCR 等模型服务 | - -### 配置文件之间的关系 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ 用户浏览器 │ -└────────────────────────┬────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ 前端 (frontend-workflow) │ -│ 配置文件: .env │ -│ - VITE_API_KEY: 与后端通信的密钥 │ -│ - VITE_DEFAULT_LLM_API_URL: 默认 LLM API 地址 │ -│ - VITE_SUPABASE_URL: 用户认证服务(可选) │ -└────────────────────────┬────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ 后端 (fastapi_app) │ -│ 配置文件: .env │ -│ - 三层模型配置架构 │ -│ - DEFAULT_LLM_API_URL: LLM API 服务地址 │ -│ - SUPABASE_*: 数据库和认证配置(可选) │ 
-└────────────────────────┬────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ 模型服务器集群 │ -│ 启动脚本: script/start_model_servers.sh │ -│ - MinerU (vLLM): 文档理解模型 │ -│ - SAM: 图像分割模型 │ -│ - OCR: 光学字符识别服务 │ -└─────────────────────────────────────────────────────────────┘ -``` - -## 🎨 前端配置详解 - -### 步骤 1: 创建前端配置文件 - -```bash -cd frontend-workflow -cp .env.example .env -``` - -### 步骤 2: 配置内部 API 通信 - -前端和后端之间通过 API 密钥进行安全通信,**这个密钥必须与后端配置保持一致**。 - -```bash -# frontend-workflow/.env -VITE_API_KEY=df-internal-2024-workflow-key -``` - -⚠️ **重要提示**: -- 这个密钥必须与 `fastapi_app/.env` 中的 `API_KEY` 完全一致 -- 生产环境中请修改为更安全的密钥 -- 不要将包含真实密钥的 `.env` 文件提交到版本控制系统 - -### 步骤 3: 配置 LLM 提供商 - -前端需要配置默认的 LLM API 地址,用户可以在界面上选择不同的 API 提供商。 - -```bash -# frontend-workflow/.env - -# 默认 LLM API URL(在 UI 的"API URL"输入框中显示) -VITE_DEFAULT_LLM_API_URL=https://api.apiyi.com/v1 - -# 可选的 LLM API URL 列表(逗号分隔,用户可在 UI 中选择) -VITE_LLM_API_URLS=https://api.apiyi.com/v1,http://b.apiyi.com:16888/v1,http://123.129.219.111:3000/v1 -``` - -**配置说明**: -- `VITE_DEFAULT_LLM_API_URL`: 前端界面默认显示的 API 地址 -- `VITE_LLM_API_URLS`: 用户可以在下拉菜单中选择的 API 地址列表 -- 用户可以在生成内容时覆盖这些默认值 - -**常见 LLM API 提供商**: -- OpenAI 官方: `https://api.openai.com/v1` -- 阿里云百炼: `https://dashscope.aliyuncs.com/compatible-mode/v1` -- DeepSeek: `https://api.deepseek.com/v1` -- 自建代理服务: 根据你的实际部署地址配置 - -### 步骤 4: 配置 Supabase(可选) - -如果你需要用户认证、配额管理和云存储功能,需要配置 Supabase。 - -```bash -# frontend-workflow/.env - -# 取消注释并填写以下配置 -VITE_SUPABASE_URL=https://your-project.supabase.co -VITE_SUPABASE_ANON_KEY=your-anon-key -SUPABASE_SERVICE_ROLE_KEY=your-service-role-key -SUPABASE_JWT_SECRET=your-jwt-secret -``` - -**获取 Supabase 配置**: -1. 访问 [Supabase Dashboard](https://supabase.com/dashboard) -2. 选择你的项目 -3. 进入 Settings → API -4. 
复制 Project URL 和 API Keys - -**如果不使用 Supabase**: -- 保持这些配置项注释状态 -- 项目将使用本地存储和无认证模式运行 - -### 前端配置完整示例 - -```bash -# =========================================== -# Internal API Configuration -# =========================================== -VITE_API_KEY=df-internal-2024-workflow-key - -# =========================================== -# LLM Provider Configuration -# =========================================== -VITE_DEFAULT_LLM_API_URL=https://api.openai.com/v1 -VITE_LLM_API_URLS=https://api.openai.com/v1,https://api.deepseek.com/v1,http://localhost:3000/v1 - -# =========================================== -# Supabase Configuration (Optional) -# =========================================== -# VITE_SUPABASE_URL=https://your-project.supabase.co -# VITE_SUPABASE_ANON_KEY=your-anon-key -# SUPABASE_SERVICE_ROLE_KEY=your-service-role-key -# SUPABASE_JWT_SECRET=your-jwt-secret -``` - -## ⚙️ 后端配置详解 - -### 步骤 1: 创建后端配置文件 - -```bash -cd fastapi_app -cp .env.example .env -``` - -### 步骤 2: 配置 Supabase(可选) - -如果前端配置了 Supabase,后端也需要相应配置。 - -```bash -# fastapi_app/.env -SUPABASE_URL=https://your-project-id.supabase.co -SUPABASE_ANON_KEY=your_supabase_anon_key -``` - -### 步骤 3: 理解三层模型配置架构 🎯 - -Paper2Any 采用了灵活的**三层模型配置架构**,让你可以在不同粒度上控制模型选择: - -``` -Layer 1: 基础模型定义 - ↓ -Layer 2: 工作流级别默认模型 - ↓ -Layer 3: 角色级别精细控制 -``` - -#### Layer 1: 基础模型定义 - -定义所有可用的模型名称,这些名称会被后续配置引用。 - -```bash -# fastapi_app/.env - -# ============================================ -# Model Configuration - Layer 1: Base Models -# ============================================ -MODEL_GPT_4O=gpt-4o -MODEL_GPT_5_1=gpt-5.1 -MODEL_CLAUDE_HAIKU=claude-haiku-4-5-20251001 -MODEL_GEMINI_PRO_IMAGE=gemini-3-pro-image-preview -MODEL_GEMINI_FLASH_IMAGE=gemini-2.5-flash-image -MODEL_GEMINI_FLASH=gemini-2.5-flash -MODEL_QWEN_VL_OCR=qwen-vl-ocr-2025-11-20 - -# 默认 LLM API URL(内部服务) -DEFAULT_LLM_API_URL=http://123.129.219.111:3000/v1/ -``` - -**配置说明**: -- 这一层定义了所有可用模型的"别名" -- 你可以根据实际使用的 API 提供商修改模型名称 -- `DEFAULT_LLM_API_URL` 是后端调用 LLM 的默认地址 - 
-#### Layer 2: 工作流级别默认模型 - -为每个工作流设置默认模型,快速切换整个工作流的模型。 - -```bash -# ============================================ -# Model Configuration - Layer 2: Workflow-level Defaults -# ============================================ - -# Paper2PPT 工作流 -PAPER2PPT_DEFAULT_MODEL=gpt-5.1 -PAPER2PPT_DEFAULT_IMAGE_MODEL=gemini-3-pro-image-preview - -# PDF2PPT 工作流 -PDF2PPT_DEFAULT_MODEL=gpt-4o -PDF2PPT_DEFAULT_IMAGE_MODEL=gemini-2.5-flash-image - -# Paper2Figure 工作流 -PAPER2FIGURE_DEFAULT_MODEL=gpt-4o -PAPER2FIGURE_DEFAULT_IMAGE_MODEL=gemini-3-pro-image-preview - -# Paper2Video 工作流 -PAPER2VIDEO_DEFAULT_MODEL=gpt-4o - -# Knowledge Base -KB_EMBEDDING_MODEL=gemini-2.5-flash -KB_CHAT_MODEL=gpt-4o -``` - -**使用场景**: -- 想要快速切换某个工作流使用的模型 -- 例如:将 Paper2PPT 从 GPT-4o 切换到 Claude Haiku -- 只需修改 `PAPER2PPT_DEFAULT_MODEL=claude-haiku-4-5-20251001` - -#### Layer 3: 角色级别精细控制 - -为工作流中的每个具体角色(任务)指定模型,实现最精细的控制。 - -```bash -# ============================================ -# Model Configuration - Layer 3: Role-level (Fine-grained Control) -# ============================================ - -# Paper2PPT 角色配置 -PAPER2PPT_OUTLINE_MODEL=gpt-5.1 # 大纲生成 -PAPER2PPT_CONTENT_MODEL=gpt-5.1 # 内容生成 -PAPER2PPT_IMAGE_GEN_MODEL=gemini-3-pro-image-preview # 图像生成 -PAPER2PPT_VLM_MODEL=qwen-vl-ocr-2025-11-20 # 视觉语言模型(OCR) -PAPER2PPT_CHART_MODEL=gpt-4o # 图表生成 -PAPER2PPT_DESC_MODEL=gpt-5.1 # 图表描述 -PAPER2PPT_TECHNICAL_MODEL=claude-haiku-4-5-20251001 # 技术细节 - -# Paper2Figure 角色配置 -PAPER2FIGURE_TEXT_MODEL=gpt-4o -PAPER2FIGURE_IMAGE_MODEL=gemini-3-pro-image-preview -PAPER2FIGURE_VLM_MODEL=qwen-vl-ocr-2025-11-20 -PAPER2FIGURE_CHART_MODEL=gpt-4o -PAPER2FIGURE_DESC_MODEL=gpt-5.1 -PAPER2FIGURE_TECHNICAL_MODEL=claude-haiku-4-5-20251001 -``` - -**使用场景**: -- 针对特定任务优化模型选择 -- 例如:OCR 任务使用专门的视觉模型 `qwen-vl-ocr` -- 技术细节提取使用 Claude Haiku(成本更低) -- 图像生成使用 Gemini Pro(效果更好) - -### 步骤 4: 理解配置优先级 - -三层配置的优先级从高到低: - -``` -Layer 3 (角色级别) > Layer 2 (工作流级别) > Layer 1 (基础定义) -``` - -**实际运行逻辑**: -1. 系统首先查找 Layer 3 的角色级别配置 -2. 如果未配置,则使用 Layer 2 的工作流级别默认值 -3. 
如果仍未配置,则使用 Layer 1 定义的基础模型 - -**实践示例**: - -假设你想让 Paper2PPT 的大纲生成使用 Claude Haiku(更便宜),但其他任务仍使用 GPT-5.1: - -```bash -# Layer 2: 工作流默认使用 GPT-5.1 -PAPER2PPT_DEFAULT_MODEL=gpt-5.1 - -# Layer 3: 只有大纲生成使用 Claude Haiku -PAPER2PPT_OUTLINE_MODEL=claude-haiku-4-5-20251001 -# 其他角色不配置,自动继承 Layer 2 的 gpt-5.1 -``` - -### 步骤 5: 后端配置完整示例 - -```bash -# ============================================ -# Supabase Configuration (Optional) -# ============================================ -SUPABASE_URL=https://your-project-id.supabase.co -SUPABASE_ANON_KEY=your_supabase_anon_key - -# ============================================ -# Model Configuration - Layer 1: Base Models -# ============================================ -MODEL_GPT_4O=gpt-4o -MODEL_GPT_5_1=gpt-5.1 -MODEL_CLAUDE_HAIKU=claude-haiku-4-5-20251001 -MODEL_GEMINI_PRO_IMAGE=gemini-3-pro-image-preview -MODEL_GEMINI_FLASH_IMAGE=gemini-2.5-flash-image -MODEL_GEMINI_FLASH=gemini-2.5-flash -MODEL_QWEN_VL_OCR=qwen-vl-ocr-2025-11-20 - -DEFAULT_LLM_API_URL=https://api.openai.com/v1/ - -# ============================================ -# Model Configuration - Layer 2: Workflow-level Defaults -# ============================================ -PAPER2PPT_DEFAULT_MODEL=gpt-4o -PAPER2PPT_DEFAULT_IMAGE_MODEL=gemini-3-pro-image-preview - -PDF2PPT_DEFAULT_MODEL=gpt-4o -PDF2PPT_DEFAULT_IMAGE_MODEL=gemini-2.5-flash-image - -# ============================================ -# Model Configuration - Layer 3: Role-level (Fine-grained Control) -# ============================================ -# 只配置需要特殊处理的角色,其他角色自动继承 Layer 2 -PAPER2PPT_VLM_MODEL=qwen-vl-ocr-2025-11-20 -PAPER2PPT_TECHNICAL_MODEL=claude-haiku-4-5-20251001 -``` - -## 🚀 模型服务器配置和启动 - -### 模型服务器架构 - -Paper2Any 使用本地模型服务器集群来处理文档解析和图像分割任务: - -``` -┌─────────────────────────────────────────────────────────────┐ -│ MinerU 集群 (vLLM) │ -│ - GPU 7, 1, 2, 3 │ -│ - 端口: 8011, 8012, 8013, 8014 │ -│ - 负载均衡器: 8010 │ -└─────────────────────────────────────────────────────────────┘ - 
-┌─────────────────────────────────────────────────────────────┐ -│ SAM 集群 (图像分割) │ -│ - GPU 4, 5, 6 │ -│ - 端口: 8021, 8022, 8023 │ -│ - 负载均衡器: 8020 │ -└─────────────────────────────────────────────────────────────┘ - -┌─────────────────────────────────────────────────────────────┐ -│ OCR 服务 (CPU) │ -│ - 端口: 8003 │ -│ - Workers: 4 │ -└─────────────────────────────────────────────────────────────┘ -``` - -### 步骤 1: 配置启动脚本 - -编辑 `script/start_model_servers.sh` 文件,根据你的硬件配置调整参数。 - -#### MinerU 配置 - -```bash -# MinerU Config -MINERU_MODEL="models/MinerU2.5-2509-1.2B" # 模型路径 -MINERU_GPU_UTIL=0.85 # GPU 显存利用率 (0-1) -MINERU_MAX_SEQS=64 # 最大并发序列数 -MINERU_GPUS=(7 1 2 3) # 使用的 GPU ID -MINERU_START_PORT=8011 # 起始端口号 -``` - -**配置说明**: -- `MINERU_MODEL`: MinerU 模型文件路径,需要提前下载 -- `MINERU_GPU_UTIL`: GPU 显存利用率,建议 0.8-0.9 -- `MINERU_MAX_SEQS`: 并发处理的序列数,影响吞吐量 -- `MINERU_GPUS`: 分配的 GPU 列表,根据你的硬件调整 -- `MINERU_START_PORT`: 第一个实例的端口,后续实例递增 - -#### SAM 配置 - -```bash -# SAM Config -SAM_GPUS=(4 5 6) # 使用的 GPU ID -SAM_START_PORT=8021 # 起始端口号 -``` - -**配置说明**: -- `SAM_GPUS`: 分配给 SAM 的 GPU 列表 -- `SAM_START_PORT`: 第一个 SAM 实例的端口 - -### 步骤 2: 启动模型服务器 - -```bash -# 从项目根目录执行 -bash script/start_model_servers.sh -``` - -**启动流程**: -1. 清理旧进程和端口占用 -2. 启动 MinerU 集群(每个 GPU 一个实例) -3. 启动 SAM 集群(每个 GPU 一个实例) -4. 启动负载均衡器(MinerU LB: 8010, SAM LB: 8020) -5. 启动 OCR 服务(端口 8003) - -**日志文件**: -- MinerU: `logs/mineru_gpu{gpu_id}.log` -- SAM: `logs/sam_{gpu_id}.log` -- 负载均衡器: `logs/mineru_lb.log`, `logs/sam_lb.log` -- OCR: `logs/ocr_server.log` - -### 步骤 3: 验证服务运行状态 - -```bash -# 查看所有日志 -tail -f logs/*.log - -# 检查 MinerU 负载均衡器 -curl http://127.0.0.1:8010/health - -# 检查 SAM 负载均衡器 -curl http://127.0.0.1:8020/health - -# 检查 OCR 服务 -curl http://127.0.0.1:8003/health - -# 查看端口占用情况 -lsof -i:8010,8020,8003 -``` - -## 🎯 完整启动流程 - -### 启动顺序 - -按照以下顺序启动整个系统: - -```bash -# 1. 启动模型服务器(如果需要本地模型) -bash script/start_model_servers.sh - -# 2. 启动后端服务 -cd fastapi_app -uvicorn main:app --host 0.0.0.0 --port 8000 --reload - -# 3. 
启动前端服务(新终端) -cd frontend-workflow -npm install # 首次运行需要安装依赖 -npm run dev -``` - -### 访问应用 - -- 前端界面: `http://localhost:5173` -- 后端 API 文档: `http://localhost:8000/docs` -- 后端健康检查: `http://localhost:8000/health` - -## ❓ 常见配置问题 - -### 问题 1: API 密钥不匹配 - -**症状**:前端无法连接后端,返回 401 或 403 错误 - -**解决方案**: -```bash -# 检查前端配置 -cat frontend-workflow/.env | grep VITE_API_KEY - -# 检查后端配置(需要在后端代码中查找 API_KEY 配置) -# 确保两者完全一致 -``` - -### 问题 2: 模型配置错误 - -**症状**:工作流运行时报错 "Model not found" 或 "Invalid model name" - -**解决方案**: -1. 检查 Layer 1 是否定义了模型名称 -2. 检查 LLM API URL 是否正确 -3. 验证 API 提供商是否支持该模型 - -```bash -# 测试 LLM API 连接 -curl -X POST http://your-api-url/v1/chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer your-api-key" \ - -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "test"}]}' -``` - -### 问题 3: 端口冲突 - -**症状**:启动模型服务器时报错 "Address already in use" - -**解决方案**: -```bash -# 查找占用端口的进程 -lsof -i:8010 - -# 杀死占用端口的进程 -kill -9 - -# 或者重新运行启动脚本(会自动清理) -bash script/start_model_servers.sh -``` - -### 问题 4: GPU 资源不足 - -**症状**:MinerU 或 SAM 启动失败,日志显示 "CUDA out of memory" - -**解决方案**: -1. 减少 GPU 分配数量 -2. 降低 `MINERU_GPU_UTIL` 参数(如 0.7) -3. 减少 `MINERU_MAX_SEQS` 参数(如 32) - -```bash -# 编辑启动脚本 -vim script/start_model_servers.sh - -# 修改配置 -MINERU_GPU_UTIL=0.7 -MINERU_MAX_SEQS=32 -MINERU_GPUS=(7 1) # 只使用 2 个 GPU -``` - -### 问题 5: Supabase 连接失败 - -**症状**:用户认证功能不可用 - -**解决方案**: -1. 如果不需要用户认证,注释掉所有 Supabase 配置 -2. 如果需要,检查 Supabase URL 和 Key 是否正确 -3. 验证 Supabase 项目是否正常运行 - -## 💡 配置最佳实践 - -### 1. 
开发环境 vs 生产环境 - -**开发环境配置**: -```bash -# 使用本地或测试 API -DEFAULT_LLM_API_URL=http://localhost:3000/v1/ - -# 使用较小的模型以节省成本 -PAPER2PPT_DEFAULT_MODEL=gpt-4o-mini -PAPER2PPT_OUTLINE_MODEL=claude-haiku-4-5-20251001 - -# 减少 GPU 资源占用 -MINERU_GPU_UTIL=0.7 -MINERU_GPUS=(0) # 只使用一个 GPU -``` - -**生产环境配置**: -```bash -# 使用稳定的 API 服务 -DEFAULT_LLM_API_URL=https://api.openai.com/v1/ - -# 使用高性能模型 -PAPER2PPT_DEFAULT_MODEL=gpt-5.1 -PAPER2PPT_IMAGE_GEN_MODEL=gemini-3-pro-image-preview - -# 充分利用 GPU 资源 -MINERU_GPU_UTIL=0.85 -MINERU_GPUS=(0 1 2 3) # 使用多个 GPU -``` - -### 2. 模型选择建议 - -**成本优化策略**: -- 大纲生成、技术细节提取:使用 Claude Haiku(成本低) -- 内容生成、图表描述:使用 GPT-4o(平衡性能和成本) -- 图像生成:使用 Gemini Pro(效果好) -- OCR 任务:使用专门的 VLM 模型(qwen-vl-ocr) - -**性能优化策略**: -- 关键任务使用最强模型(GPT-5.1, Claude Opus) -- 并行任务使用不同模型避免 API 限流 -- 图像任务使用专门的多模态模型 - -### 3. 安全配置建议 - -```bash -# ❌ 不要在代码中硬编码密钥 -# ❌ 不要将 .env 文件提交到 Git - -# ✅ 使用环境变量 -export VITE_API_KEY="your-secret-key" - -# ✅ 在 .gitignore 中排除配置文件 -echo ".env" >> .gitignore -echo "*.env" >> .gitignore - -# ✅ 生产环境使用强密钥 -VITE_API_KEY=$(openssl rand -hex 32) -``` - -### 4. 性能优化建议 - -**GPU 资源优化**: -```bash -# 根据 GPU 显存调整参数 -# 24GB GPU: MINERU_GPU_UTIL=0.85, MINERU_MAX_SEQS=64 -# 16GB GPU: MINERU_GPU_UTIL=0.75, MINERU_MAX_SEQS=32 -# 8GB GPU: MINERU_GPU_UTIL=0.65, MINERU_MAX_SEQS=16 -``` - -**并发优化**: -```bash -# OCR 服务 workers 数量根据 CPU 核心数调整 -# 8 核 CPU: --workers 4 -# 16 核 CPU: --workers 8 -# 32 核 CPU: --workers 16 -``` - -### 5. 配置文件管理 - -**推荐的配置文件结构**: -``` -project/ -├── .env.example # 配置模板(提交到 Git) -├── .env # 本地配置(不提交) -├── .env.development # 开发环境配置 -├── .env.production # 生产环境配置 -└── .env.test # 测试环境配置 -``` - -**切换环境**: -```bash -# 开发环境 -cp .env.development .env - -# 生产环境 -cp .env.production .env -``` - -## 📚 相关文档 - -- [安装指南](../installation.md) - 环境搭建和依赖安装 -- [快速开始](../quickstart.md) - 快速体验各项功能 -- [CLI 工具](../cli.md) - 命令行工具使用说明 - -## 🎉 配置完成 - -恭喜!你已经完成了 Paper2Any 项目的配置。现在可以: - -1. 启动模型服务器(如果需要) -2. 启动后端服务 -3. 启动前端服务 -4. 
访问 `http://localhost:5173` 开始使用 - -如果遇到问题,请参考上面的常见问题部分,或查看项目的 [FAQ 文档](../faq.md)。 diff --git a/docs/guides/multimodal_api.md b/docs/guides/multimodal_api.md deleted file mode 100644 index 8873727..0000000 --- a/docs/guides/multimodal_api.md +++ /dev/null @@ -1,123 +0,0 @@ -# 多模态供应商与 API 开发指南 - -DataFlow-Agent 采用灵活的策略模式来支持多种多模态 AI 供应商(如 OpenAI DALL-E, Google Gemini, APIYI 等)。本指南将介绍如何扩展系统以支持新的多模态 API。 - -## 核心架构 - -多模态功能的实现主要依赖于以下几个核心组件: - -1. **`VisionLLMCaller`** (`dataflow_agent/llm_callers/image.py`): - * 这是多模态调用的统一入口。 - * 支持 `generation` (生图), `edit` (修图), `understanding` (图像理解), `ocr` (文字识别), `video_understanding` (视频理解) 等模式。 - * 根据配置的 `mode` 调用相应的底层逻辑。 - -2. **`AIProviderStrategy`** (`dataflow_agent/toolkits/multimodaltool/providers.py`): - * 这是一个抽象基类,定义了所有多模态供应商必须实现的接口。 - * 包括请求构建 (`build_request`) 和响应解析 (`parse_response`)。 - -3. **`VLMStrategy`** (`dataflow_agent/agentroles/cores/strategies.py`): - * Agent 执行策略的一种,负责配置和调度 `VisionLLMCaller`。 - -## 接口定义 - -所有的供应商策略都继承自 `AIProviderStrategy`。主要接口如下: - -```python -class AIProviderStrategy(ABC): - - @abstractmethod - def match(self, api_url: str, model: str) -> bool: - """判断当前策略是否适用于给定的 API URL 和模型名称""" - pass - - @abstractmethod - def build_generation_request(self, api_url, model, prompt, **kwargs) -> Tuple[str, Dict, bool]: - """构造文生图请求。返回: (url, payload, is_stream)""" - pass - - @abstractmethod - def parse_generation_response(self, response_data: Dict) -> str: - """解析生图响应,返回 Base64 图片字符串""" - pass - - # 可选实现:图生图/编辑 - def build_edit_request(self, ...): ... - - # 可选实现:多图编辑 - def build_multi_image_edit_request(self, ...): ... - - # 可选实现:TTS (语音合成) - def build_tts_request(self, ...): ... - - # 可选实现:对话/多模态理解 (默认实现了 OpenAI 兼容格式) - def build_chat_request(self, ...): ... - def parse_chat_response(self, ...): ... 
-``` - -## 如何添加新供应商 - -要添加一个新的多模态供应商(例如 "MyNewAI"),请按照以下步骤操作: - -### 步骤 1: 创建策略类 - -在 `dataflow_agent/toolkits/multimodaltool/providers.py` 中定义一个新的类,继承自 `AIProviderStrategy`。 - -```python -class MyNewAIProvider(AIProviderStrategy): - """ - MyNewAI 服务商支持 - """ - def match(self, api_url: str, model: str) -> bool: - # 定义匹配规则,例如检查 URL 或模型前缀 - return "mynewai.com" in api_url or model.startswith("mynewai-") - - def build_generation_request(self, api_url: str, model: str, prompt: str, **kwargs) -> Tuple[str, Dict[str, Any], bool]: - # 构造请求 - url = f"{api_url.rstrip('/')}/v1/images/generate" - payload = { - "model": model, - "text": prompt, - "width": 1024, - "height": 1024 - } - is_stream = False - return url, payload, is_stream - - def parse_generation_response(self, data: Dict[str, Any]) -> str: - # 解析响应,提取图片 Base64 - if "data" in data and "image_base64" in data["data"]: - return data["data"]["image_base64"] - raise RuntimeError("Failed to parse MyNewAI response") -``` - -### 步骤 2: 实现高级功能 (可选) - -如果该供应商支持图生图 (Image Editing) 或多模态理解,请重写相应的方法: - -* `build_edit_request`: 用于处理图像编辑任务。 -* `build_chat_request` / `parse_chat_response`: 用于处理 Vision/Chat 任务(如果 API 格式不兼容 OpenAI 标准)。 - -### 步骤 3: 注册策略 - -在 `dataflow_agent/toolkits/multimodaltool/providers.py` 文件底部的 `STRATEGIES` 列表中注册你的新策略类。**注意顺序很重要**,系统会按顺序尝试匹配。 - -```python -STRATEGIES = [ - ApiYiGeminiProvider(), - MyNewAIProvider(), # <--- 添加在这里 - ApiYiSeeDreamProvider(), - # ... - OpenAICompatGeminiProvider(), # 默认回退 -] -``` - -## 调试与测试 - -1. 配置环境变量 `DF_API_URL` 和 `DF_API_KEY` 指向你的新服务商。 -2. 使用 `VisionLLMCaller` 进行测试,或者直接运行 `dataflow_agent/llm_callers/image.py` 中的 `_quick_test` 函数(需要稍作修改以支持你的模型参数)。 -3. 
检查 `dataflow_agent.logger` 输出的日志以排查请求构建或响应解析的问题。 - -## 常见问题 - -* **Multipart Upload**: 如果编辑接口需要上传文件(`multipart/form-data`),在 `build_edit_request` 中返回的 payload 应包含 `{"__is_multipart__": True, "files": ..., "data": ...}`。参考 `ApiYiGPTImageProvider` 的实现。 -* **流式响应**: 如果 API 返回流式数据 (SSE),`build_*_request` 返回的第三个参数 `is_stream` 应设为 `True`。 diff --git a/docs/guides/paper2figure.md b/docs/guides/paper2figure.md deleted file mode 100644 index 2db0709..0000000 --- a/docs/guides/paper2figure.md +++ /dev/null @@ -1,11 +0,0 @@ -# Paper2Figure 指南 - -> **Paper2Figure** 专注于从论文内容一键生成科研绘图,包括模型架构图、实验对比图等。 - -## 功能介绍 - -* **模型架构图生成**:自动分析论文中的模型描述,生成可视化的架构图。 -* **实验数据图生成**:提取实验数据表格,生成对比柱状图、折线图等。 -* **可编辑性**:生成的图表支持 PPTX 格式导出,方便二次编辑。 - -*(详细文档正在编写中...)* diff --git a/docs/guides/paper2ppt.md b/docs/guides/paper2ppt.md deleted file mode 100644 index 3f8fa81..0000000 --- a/docs/guides/paper2ppt.md +++ /dev/null @@ -1,12 +0,0 @@ -# Paper2PPT 指南 - -> **Paper2PPT** 能够将学术论文快速转换为结构化的演示文稿 (PPT)。 - -## 功能介绍 - -* **智能大纲生成**:自动提取论文核心观点,生成 PPT 大纲。 -* **长文档支持**:支持处理超长论文,自动分页与总结。 -* **图表提取与插入**:自动识别并提取论文中的图片和表格,插入到对应的 PPT 页面中。 -* **多种模板**:支持 Beamer 风格及多种自定义 PPT 模板。 - -*(详细文档正在编写中...)* diff --git a/docs/guides/paper2technical.md b/docs/guides/paper2technical.md deleted file mode 100644 index 2a2243b..0000000 --- a/docs/guides/paper2technical.md +++ /dev/null @@ -1,11 +0,0 @@ -# Paper2Technical 指南 - -> **Paper2Technical** 专注于生成技术路线图 (Technical Roadmap)。 - -## 功能介绍 - -* **流程分析**:深入分析算法流程或系统架构。 -* **路线图生成**:生成清晰的步骤图、流程图或技术演进路线图。 -* **SVG/PPTX 输出**:支持矢量格式输出,满足高质量发表需求。 - -*(详细文档正在编写中...)* diff --git a/docs/guides/paper2video.md b/docs/guides/paper2video.md deleted file mode 100644 index 5099ba0..0000000 --- a/docs/guides/paper2video.md +++ /dev/null @@ -1,11 +0,0 @@ -# Paper2Video 指南 - -> **Paper2Video** 旨在辅助研究人员将论文内容转换为视频脚本或演示视频。 - -## 功能介绍 - -* **视频脚本生成**:根据论文结构生成适合视频讲解的脚本 (Script)。 -* **关键帧规划**:建议视频各阶段的视觉画面或关键帧。 -* **多模态合成** (开发中):结合 TTS 和图像生成,自动合成演示视频。 - -*(详细文档正在编写中...)* diff --git a/docs/index.md 
b/docs/index.md deleted file mode 100644 index 0f50c13..0000000 --- a/docs/index.md +++ /dev/null @@ -1,210 +0,0 @@ -# Paper2Any 项目文档 - -
- -**从论文到多模态输出的智能化工作流平台** - - - -
- ---- - -## 💡 项目简介 - -**Paper2Any** 是一个基于深度学习的智能化工作流平台,专注于将学术论文转换为多种形式的输出,包括示意图、PPT、视频、技术报告等。通过集成最新的多模态大模型和计算机视觉技术,Paper2Any 能够自动解析论文内容并生成高质量的视觉和文本输出。 - -### 核心优势 - -- 🎯 **多模态输出**:支持从论文生成示意图(Figure)、PPT、视频(Video)、技术报告(Technical Report)等多种格式 -- 🔌 **模块化设计**:基于 DataFlow-Agent 框架,工作流可灵活组合和扩展 -- 🎨 **高质量生成**:集成前沿的视觉生成模型和文本生成模型,确保输出质量 -- ⚡ **高效处理**:支持批量处理和并行计算,快速处理大量论文 -- 🔄 **灵活部署**:提供 Docker 容器化部署和本地部署选项 - ---- - -## ✨ 核心功能 - -### 📊 Paper2Figure -从论文中提取关键信息,自动生成高质量的示意图和图表,支持学术演示和论文插图需求。 - -### 📽️ Paper2PPT -基于论文内容自动生成结构化的 PowerPoint 演示文稿,包括封面、目录、内容页和参考文献页。 - -### 🎬 Paper2Video -将论文内容转换为讲解视频,自动生成脚本、配音和视觉内容,适合快速了解论文核心思想。 - -### 📝 Paper2Technical -提取论文的技术细节,生成详细的技术报告、方法描述和实现指南。 - -### 🔧 其他功能 -- **PDF2PPT**:将现有的PDF文件转换为可编辑的PPT演示文稿 -- **Paper2ExpFigure**:为论文生成实验数据图表 -- **Paper2PageContent**:提取论文页面内容,用于知识库构建 - ---- - -## 🚀 快速开始 - -### 环境要求 - -- **Python**: 3.10 或更高版本([下载 Python](https://www.python.org/downloads/)) -- **操作系统**: Linux (推荐) / Windows / macOS -- **GPU**: 推荐 NVIDIA GPU(用于视觉生成任务) -- **内存**: 至少 16GB RAM - -### 安装步骤 - -#### 1. 克隆仓库 - -```bash -git clone https://github.com/OpenDCAI/Paper2Any.git -cd Paper2Any -``` - -#### 2. 创建虚拟环境(推荐) - -```bash -# 使用 venv -python -m venv venv -source venv/bin/activate # Windows: venv\Scripts\activate - -# 或使用 conda -conda create -n paper2any python=3.10 -conda activate paper2any -``` - -#### 3. 安装依赖 - -```bash -# 安装基础依赖 -pip install -r requirements-base.txt - -# 安装开发依赖(可选) -pip install -r requirements-dev.txt - -# 安装Paper2Any包 -pip install -e . -``` - -#### 4. 配置模型服务 - -某些功能需要运行额外的模型服务。请参考[安装指南](installation.md)的详细说明。 - -#### 5. 
启动应用 - -```bash -# 启动 Gradio Web 界面(推荐用于测试) -python gradio_app/app.py -``` - -访问 **http://127.0.0.1:7860** 使用可视化界面。 - -或者使用 FastAPI 后端: - -```bash -# 启动 FastAPI 后端 -cd fastapi_app -uvicorn main:app --host 0.0.0.0 --port 8000 --reload -``` - ---- - -## 📖 文档导航 - -- **[快速开始](quickstart.md)** - 新手入门指南 -- **[安装指南](installation.md)** - 详细安装和配置说明 -- **[功能指南](guides/)** - 各功能模块的详细使用说明 - - [Paper2Figure](guides/paper2figure.md) - - [Paper2PPT](guides/paper2ppt.md) - - [Paper2Video](guides/paper2video.md) - - [Paper2Technical](guides/paper2technical.md) -- **[CLI工具](cli.md)** - 命令行工具使用说明 -- **[常见问题解答](faq.md)** - 常见问题解决方法 -- **[贡献指南](contributing.md)** - 参与项目开发的指南 -- **[更新日志](changelog.md)** - 版本更新记录 - ---- - -## 🏗️ 系统架构 - -``` -Paper2Any/ -├── dataflow_agent/ # 底层工作流引擎 -│ ├── agentroles/ # Agent 角色定义 -│ ├── workflow/ # 工作流定义 (wf_*.py) -│ ├── toolkits/ # 工具集 -│ └── ... -├── fastapi_app/ # FastAPI 后端服务 -│ ├── routers/ # API 路由 -│ ├── workflow_adapters/ # 工作流适配器 -│ └── ... -├── gradio_app/ # Gradio Web 界面 -│ ├── app.py # 主应用入口 -│ └── pages/ # 页面模块 -├── frontend-workflow/ # 前端界面 (Vite + TypeScript) -├── script/ # 运行脚本 -├── docs/ # 项目文档 -├── tests/ # 测试文件 -└── outputs/ # 输出目录 -``` - - ---- - -## 🤝 参与贡献 - -我们欢迎任何形式的贡献!无论是提交 Bug、提出新功能建议,还是改进文档。 - -### 贡献流程 - -1. **Fork 本仓库**并克隆到本地 -2. **创建功能分支**: `git checkout -b feature/amazing-feature` -3. **提交代码**: `git commit -m 'Add amazing feature'` -4. **推送到分支**: `git push origin feature/amazing-feature` -5. **提交 Pull Request** - -### 代码规范 - -- 遵循 PEP 8 Python 代码风格 -- 为新功能添加单元测试 -- 更新相关文档(包括 docstring 和 MkDocs 文档) -- 提交信息清晰描述变更内容 - -详见 [贡献指南](contributing.md)。 - ---- - -## 📄 开源协议 - -本项目采用 **Apache License 2.0** 开源协议。详情请查看 [LICENSE](LICENSE) 文件。 - ---- - -## 🙏 致谢 - -感谢所有为本项目做出贡献的开发者和使用者! 
- -特别鸣谢: -- [DataFlow-Agent](https://github.com/OpenDCAI/Paper2Any) - 底层工作流框架 -- [Gradio](https://gradio.app/) - 优秀的 Web 界面框架 -- [FastAPI](https://fastapi.tiangolo.com/) - 高性能 API 框架 -- [LangGraph](https://github.com/langchain-ai/langgraph) - 工作流编排灵感来源 - ---- - -## 📞 联系我们 - -- **问题反馈**: [GitHub Issues](https://github.com/OpenDCAI/Paper2Any/issues) -- **讨论交流**: [GitHub Discussions](https://github.com/OpenDCAI/Paper2Any/discussions) - ---- - -
- -**如果这个项目对你有帮助,请给我们一个 ⭐️ Star!** - -Made with ❤️ by Paper2Any Team - -
diff --git a/docs/installation.md b/docs/installation.md deleted file mode 100644 index 5fba14c..0000000 --- a/docs/installation.md +++ /dev/null @@ -1,268 +0,0 @@ -# 安装指南 - -本指南将帮助您完成 Paper2Any 的安装和环境配置。 - -## 环境要求 - -### 系统要求 -- **操作系统**: Linux (推荐), Windows 10/11, macOS 10.15+ -- **Python**: 3.10 或更高版本 -- **内存**: 至少 16GB RAM(推荐 32GB+ 用于大模型推理) -- **存储**: 至少 50GB 可用空间(用于模型缓存和输出文件) -- **GPU**: 可选但推荐(用于加速视觉生成任务) - - NVIDIA GPU(支持 CUDA 11.8+) - - 至少 8GB 显存(推荐 16GB+) - -### 网络要求 -- 稳定的互联网连接(用于下载模型和依赖) -- 能够访问 GitHub, PyPI, HuggingFace - -## 安装步骤 - -### 1. 克隆仓库 - -```bash -git clone https://github.com/OpenDCAI/Paper2Any.git -cd Paper2Any -``` - -### 2. 创建虚拟环境(推荐) - -#### 使用 venv (Python 内置) -```bash -python -m venv venv - -# Linux/macOS -source venv/bin/activate - -# Windows -venv\Scripts\activate -``` - -#### 使用 conda -```bash -conda create -n paper2any python=3.10 -conda activate paper2any -``` - -### 3. 安装基础依赖 - -Paper2Any 提供了多个依赖文件以适应不同场景: - -```bash -# 安装核心依赖(必须) -pip install -r requirements-base.txt - -# 安装开发依赖(推荐,包含测试和工具) -pip install -r requirements-dev.txt - -# 安装 Paper2Any 包(开发模式) -pip install -e . -``` - -#### Windows 用户注意事项 -Windows 用户可以使用 `requirements-win-base.txt` 替代 `requirements-base.txt`: - -```bash -pip install -r requirements-win-base.txt -``` - -### 4. 
模型服务配置 - -Paper2Any 依赖多个外部模型服务来完成各种任务。您需要配置以下服务: - -#### 4.1 文本生成模型(必需) -Paper2Any 需要 LLM 服务来处理文本生成任务。您有以下选择: - -**选项 A:使用本地部署的模型服务** -```bash -# 示例:使用 Ollama 部署本地模型 -ollama pull qwen2.5:7b -ollama serve -``` - -**选项 B:使用云 API 服务** -- OpenAI GPT 系列(需 API Key) -- 阿里云通义千问(需 API Key) -- DeepSeek(需 API Key) - -在 `.env` 文件中配置 API 信息: -```bash -# 复制示例配置文件 -cp .env.example .env - -# 编辑 .env 文件,填入您的 API 配置 -LLM_API_URL="https://api.openai.com/v1" -LLM_API_KEY="your-api-key" -LLM_MODEL="gpt-4o" -``` - -#### 4.2 图像生成模型(可选,用于 Paper2Figure/Paper2PPT) -如果使用图像生成功能,需要配置: - -- **Stable Diffusion** 或 **DALL-E** API -- 或部署本地 SD WebUI - -相关配置可参考 `script/start_model_servers.sh` - -#### 4.3 其他模型服务 -- **OCR 服务**: 用于提取 PDF 文本(如 PaddleOCR) -- **语音合成**: 用于 Paper2Video(如 Edge-TTS) -- **视频生成**: 用于 Paper2Video(如 Stable Video Diffusion) - -### 5. 数据库配置(可选) - -Paper2Any 使用 SQLite 作为默认数据库。如需使用其他数据库: - -#### SQLite(默认) -无需额外配置,首次运行会自动创建数据库。 - -#### PostgreSQL -1. 安装 PostgreSQL 和 psycopg2: - ```bash - pip install psycopg2-binary - ``` -2. 创建数据库: - ```sql - CREATE DATABASE paper2any; - ``` -3. 在 `.env` 中配置连接字符串: - ``` - DATABASE_URL=postgresql://user:password@localhost:5432/paper2any - ``` - -### 6. 验证安装 - -运行以下命令验证安装是否成功: - -```bash -# 运行简单测试 -python -c "import dataflow_agent; print('DataFlow-Agent installed successfully')" - -# 测试 Paper2Any 工作流 -python script/run_paper2figure.py --help -``` - -如果看到帮助信息,说明安装成功。 - -## Docker 安装(推荐用于生产环境) - -### 使用 Docker Compose(最简单) - -```bash -# 启动所有服务(包括数据库和模型服务) -docker-compose up -d - -# 查看服务状态 -docker-compose ps - -# 查看日志 -docker-compose logs -f -``` - -### 自定义 Docker 构建 - -```bash -# 构建镜像 -docker build -t paper2any:latest . - -# 运行容器 -docker run -p 7860:7860 -p 8000:8000 \ - -v $(pwd)/outputs:/app/outputs \ - -v $(pwd)/models:/app/models \ - paper2any:latest -``` - -### Docker Compose 配置文件 - -项目提供的 `docker-compose.yml` 包含三个主要服务: - -1. **paper2any-app**: 主应用服务(Gradio + FastAPI) -2. **model-server**: 模型推理服务(可配置) -3. 
**postgres**: 数据库服务(可选) - -## 开发环境配置 - -### IDE 配置 - -推荐使用 VS Code 或 PyCharm 作为开发环境: - -#### VS Code 配置 -1. 安装 Python 扩展 -2. 配置工作区设置: - ```json - { - "python.defaultInterpreterPath": "./venv/bin/python", - "python.linting.enabled": true, - "python.linting.pylintEnabled": true, - "python.formatting.provider": "black" - } - ``` - -#### PyCharm 配置 -1. 设置虚拟环境解释器 -2. 启用自动代码格式化 -3. 配置运行/调试配置 - -### 预提交钩子(代码质量) - -```bash -# 安装预提交钩子 -pre-commit install - -# 手动运行检查 -pre-commit run --all-files -``` - -## 故障排除 - -### 常见问题 - -#### 1. 依赖安装失败 -- **问题**: `pip install` 失败,提示版本冲突 -- **解决**: 使用虚拟环境,或尝试: - ```bash - pip install --upgrade pip setuptools wheel - pip install -r requirements-base.txt --no-deps - ``` - -#### 2. CUDA 相关错误 -- **问题**: 无法导入 torch 或 tensorflow -- **解决**: 确保安装正确版本的 CUDA 工具包: - ```bash - # 查看 CUDA 版本 - nvcc --version - - # 安装对应版本的 PyTorch - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 - ``` - -#### 3. 模型下载慢 -- **问题**: 下载 HuggingFace 模型速度慢 -- **解决**: 使用镜像源: - ```bash - export HF_ENDPOINT=https://hf-mirror.com - ``` - -#### 4. 内存不足 -- **问题**: 运行大模型时内存溢出 -- **解决**: - - 使用 CPU 模式(性能较低) - - 增加交换空间 - - 使用量化模型 - -### 获取帮助 - -如果遇到无法解决的问题: - -1. 查看项目的 [GitHub Issues](https://github.com/OpenDCAI/Paper2Any/issues) -2. 搜索类似问题的解决方案 -3. 提交新的 Issue(包含详细错误信息) - -## 下一步 - -安装完成后,请继续: - -- 📖 [快速开始](quickstart.md) - 学习如何使用 Paper2Any 的基本功能 -- 🛠️ [功能指南](guides/) - 深入了解各功能模块 -- 🐳 [部署指南](guides/deployment.md) - 学习如何部署到生产环境 diff --git a/docs/quickstart.md b/docs/quickstart.md deleted file mode 100644 index d3d1436..0000000 --- a/docs/quickstart.md +++ /dev/null @@ -1,210 +0,0 @@ -# 🚀 Paper2Any 快速开始指南 - -本指南帮助您快速上手 Paper2Any 的核心功能。在完成 [安装指南](./installation.md) 后,您可以通过以下方式快速体验 Paper2Any 的强大能力。 - -## 📊 快速体验 Paper2Figure:科研绘图 - -Paper2Figure 支持三种主要绘图模式:模型架构图、技术路线图、实验数据图。 - -### 方式一:命令行快速生成(推荐) - -1. **模型架构图生成**: - ```bash - python script/run_paper2figure.py --input "tests/2506.02454v1.pdf" --model architecture - ``` - -2. 
**技术路线图生成**: - ```bash - python script/run_paper2technical.py --input "tests/2506.02454v1.pdf" - ``` - -3. **实验数据图生成**: - ```bash - python script/run_paper2expfigure.py --input "tests/2506.02454v1.pdf" - ``` - -### 方式二:Web界面交互体验 - -1. **启动后端服务**: - ```bash - cd fastapi_app - uvicorn main:app --host 0.0.0.0 --port 8000 - ``` - -2. **启动前端界面**: - ```bash - cd frontend-workflow - npm run dev - ``` - -3. **访问界面**: - 打开浏览器访问 `http://localhost:3000`,选择"Paper2Figure"功能模块,上传论文PDF或输入文本即可快速生成。 - -## 🎬 快速体验 Paper2PPT:论文转演示文稿 - -### 方式一:命令行生成 - -```bash -# 从论文PDF生成PPT -python script/run_paper2ppt.py --input "tests/2506.02454v1.pdf" - -# 从文本生成PPT -python script/run_paper2ppt.py --text "深度学习在医疗影像分析中的应用" - -# 生成长文档PPT(40+页) -python script/run_paper2ppt.py --input "long_paper.pdf" --long_doc -``` - -### 方式二:Web界面使用 - -1. 确保后端和前端服务正在运行(同上) -2. 访问 `http://localhost:3000`,选择"Paper2PPT"功能模块 -3. 上传论文PDF或输入主题,选择风格模板,点击生成 - -## 🖼️ 快速体验 PDF2PPT:保持版式的PDF转换 - -### 方式一:命令行生成 - -```bash -# 基本转换 -python script/run_pdf2ppt_parallel.py --input "tests/test_02.pdf" - -# 使用MinerU优化版 -python script/run_pdf2ppt_with_paddle_sam_mineru.py --input "tests/test_02.pdf" -``` - -### 方式二:Web界面使用 - -1. 访问 `http://localhost:3000`,选择"PDF2PPT"功能模块 -2. 上传PDF文件,系统会自动进行智能抠图和版式分析 -3. 下载可编辑的PPTX文件 - -## 🎨 快速体验 Image2PPT:图片转演示文稿 - -### 方式一:命令行生成 - -```bash -python script/run_image2ppt.py --image "tests/test_02.png" -``` - -### 方式二:Web界面使用 - -1. 访问 `http://localhost:3000`,选择"Image2PPT"功能模块 -2. 上传图片文件(支持PNG、JPG、JPEG格式) -3. 
系统会自动分析图片内容并生成PPT - -## ⚡ 快速脚本说明 - -Paper2Any 提供了多个快速脚本,位于 `script/` 目录下: - -| 脚本文件 | 功能 | 常用参数 | -|---------|------|----------| -| `run_paper2figure.py` | 模型架构图生成 | `--input`, `--model`, `--output_dir` | -| `run_paper2technical.py` | 技术路线图生成 | `--input`, `--style`, `--output_dir` | -| `run_paper2expfigure.py` | 实验数据图生成 | `--input`, `--chart_type`, `--output_dir` | -| `run_paper2ppt.py` | 论文转PPT | `--input`, `--text`, `--long_doc`, `--output` | -| `run_pdf2ppt_parallel.py` | PDF转PPT(并行版) | `--input`, `--output`, `--workers` | -| `run_pdf2ppt_with_paddle_sam_mineru.py` | PDF转PPT(优化版) | `--input`, `--output`, `--gpu_id` | -| `run_image2ppt.py` | 图片转PPT | `--image`, `--output` | - -## 🔧 配置说明 - -### 环境变量配置 - -在运行前,请确保已配置必要的环境变量: - -```bash -# API密钥配置 -export DF_API_KEY="your_api_key_here" - -# 可选:自定义API端点 -export DF_API_URL="http://your-api-gateway/v1/" - -# 可选:MinerU GPU资源配置 -export MINERU_DEVICES="0,1" # 使用GPU 0和1 -``` - -### Supabase配置(Web功能必需) - -在 `frontend-workflow/.env` 文件中配置: - -```bash -VITE_SUPABASE_URL=your_supabase_url -VITE_SUPABASE_ANON_KEY=your_supabase_anon_key -SUPABASE_URL=your_supabase_url -SUPABASE_ANON_KEY=your_supabase_anon_key -SUPABASE_SERVICE_ROLE_KEY=your_service_role_key -SUPABASE_JWT_SECRET=your_jwt_secret -``` - -## 📁 输出文件说明 - -所有生成的输出文件默认保存在以下目录: - -- **Paper2Figure输出**:`outputs/paper2fig_ppt/{timestamp}/` - - `ppt_pages/`:PPT页面图片 - - `clean_backgrounds/`:去背景后的图片 - - `final_output.pptx`:最终PPT文件 - -- **Paper2PPT输出**:`outputs/paper2ppt/{timestamp}/` - - `ppt_pages/`:PPT页面 - - `final_output.pptx`:最终PPT文件 - -- **PDF2PPT输出**:`outputs/pdf2ppt/{timestamp}/` - - `pages/`:处理后的页面 - - `final_output.pptx`:最终PPT文件 - -## 🐳 Docker快速体验 - -如果您不想在本地安装环境,可以使用Docker快速体验: - -```bash -# 克隆项目 -git clone https://github.com/OpenDCAI/Paper2Any.git -cd Paper2Any - -# 启动所有服务 -docker-compose up -d - -# 访问Web界面 -# 前端:http://localhost:3000 -# 后端API:http://localhost:8000 -``` - -## ❓ 常见问题 - -### Q1: 运行时提示缺少依赖? 
-A: 请确保已按照 [安装指南](./installation.md) 安装了所有依赖,特别是LaTeX引擎(tectonic)和系统工具(Inkscape、LibreOffice)。 - -### Q2: 生成速度慢怎么办? -A: 可以尝试以下优化: -1. 使用 `--workers` 参数并行处理(如果脚本支持) -2. 确保已正确配置GPU资源(对于MinerU等需要GPU的组件) -3. 调整模型服务配置,减少等待时间 - -### Q3: 如何自定义生成风格? -A: 大多数脚本支持 `--style` 或 `--template` 参数,可以指定不同的生成风格。您也可以修改 `dataflow_agent/promptstemplates/` 中的提示词模板来自定义风格。 - -### Q4: 生成的PPT无法编辑? -A: Paper2Any 生成的PPT是完全可编辑的PPTX格式。如果遇到问题,请确保: -1. 使用最新版本的Microsoft PowerPoint或LibreOffice -2. 检查文件扩展名是否为 `.pptx` -3. 尝试使用脚本的 `--output_format pptx` 参数(如果支持) - -## 📚 下一步 - -- 查看 [详细功能指南](../docs/guides/) 了解每个功能的深度用法 -- 了解 [系统架构](../docs/index.md#系统架构) 理解内部工作原理 -- 参考 [API文档](../docs/api/) 进行二次开发 -- 参与 [贡献指南](../docs/contributing.md) 帮助改进项目 - -## 🆘 获取帮助 - -如果在使用过程中遇到问题: -1. 查看 [FAQ](./faq.md) 寻找常见问题解答 -2. 提交 [GitHub Issue](https://github.com/OpenDCAI/Paper2Any/issues) -3. 加入 [微信社群](../README.md#wechat-group) 获取实时帮助 - ---- - -**开始您的Paper2Any之旅吧!** 🎉 diff --git a/docs/tech_route_template_palette_react.md b/docs/tech_route_template_palette_react.md deleted file mode 100644 index 9dd1eab..0000000 --- a/docs/tech_route_template_palette_react.md +++ /dev/null @@ -1,272 +0,0 @@ -# 技术路线图:模板参考 + 可选配色 + SVG 渲染校验(ReAct重试) - -本文档记录本次开发任务的:需求、具体设计、开发计划、当前开发进度与后续待办。 - ---- - -## 1. 背景与现状 - -现有 `paper2technical` 工作流会生成一份技术路线图 SVG,并同时渲染 PNG,前端提供: -- SVG 源文件下载按钮 -- SVG 图片(PNG 预览)URL - -当前问题/目标: -- 生成 SVG 过程中偶发语法错误或渲染失败(导致显示乱码/无法预览)。 -- 需要支持“可选配色”:不选配色只输出黑白;选配色同时输出黑白+彩色两套,并在前端提供两套下载/链接。 -- 生成 SVG 的整体结构/样式需参考一张固定模板图(`temp.png`),但允许根据用户内容微调排版。 - -> 备注:本次 ReAct 的目的仅用于 **SVG 语法正确性/可渲染性**,不用于语义完整性检查。 - ---- - -## 2. 
需求(确认版) - -### 2.1 输出规则 - -1) **未选择配色**(`tech_route_palette == ""`): -- 仅生成黑白 SVG(以及对应 PNG)。 -- 仍使用原有字段:`svg_filename/svg_image_filename` 指向黑白版本。 - -2) **选择配色**(`tech_route_palette != ""`): -- 同时生成黑白 SVG + 彩色 SVG(以及各自 PNG)。 -- `svg_filename/svg_image_filename` 继续指向黑白(兼容旧逻辑)。 -- 新增两套字段用于彩色: - - `svg_color_filename/svg_color_image_filename` - - 同时也新增黑白明确字段:`svg_bw_filename/svg_bw_image_filename` -- 前端展示两套下载按钮与 URL。 - -### 2.2 ReAct 诉求(仅渲染正确性) - -黑白 SVG 生成后需要验证: -- SVG XML 结构合法(可解析) -- 必须有 `viewBox` -- 必须能通过 CairoSVG 渲染为 PNG(`local_tool_for_svg_render` 成功) - -若验证失败,则应自动重试(最多 N 次),并把失败原因反馈给模型以修复。 - -> 注意:由于不能修改 `BaseAgent`,这里的“ReAct”以 **workflow 节点内的循环重试 + validator** 形式实现。 - -### 2.3 模板与上色要求 - -- 模板:使用仓库内 `temp.png` 作为技术路线图模板,生成时需参考该图的格式/层级/布局风格,可微调排版以适配实际内容。 -- 上色:同类型或同层级内容最好使用同一颜色(例如同一个阶段的节点一致、箭头一致、文字一致)。 -- 色卡:提供 3–4 套预设,每套约 3–4 种颜色,并在前端可预览色卡颜色。 - ---- - -## 3. 具体设计 - -### 3.1 模板文件与默认路径 - -- 模板文件落位: - - `static/paper2any_imgs/p2t/temp.png` -- Workflow 默认模板路径: - - 优先 `request.tech_route_template`(支持传入 `temp.png` 或绝对路径) - - 为空时使用默认 `static/paper2any_imgs/p2t/temp.png` - -### 3.2 色卡设计(预设 4 套,每套 4 色,可视化) - -色卡字段含义(用于上色 agent): -- `colors`: 提供给模型可用颜色集合(4个 hex) -- `level_colors`: “同层级/同阶段颜色”的推荐序列(4个 hex) -- `arrow_color`: 箭头/连线强调色 -- `text_color`: 文本颜色(确保可读性) - -目前内置色卡(与前端保持一致): -- `academic_blue`: `#1F6FEB #60A5FA #A7C7FF #0B3D91` -- `teal_orange`: `#0F766E #14B8A6 #F59E0B #FB923C` -- `slate_rose`: `#334155 #64748B #F43F5E #FCA5A5` -- `indigo_amber`: `#4338CA #6366F1 #F59E0B #FCD34D` - -### 3.3 Workflow 设计(不改 BaseAgent) - -工作流:`dataflow_agent/workflow/wf_paper2technical.py` - -整体流程: -1) `_start_`:初始化 `result_path` -2) `paper_idea_extractor`(PDF 模式才走) -3) `technical_route_bw_svg_generator`:黑白 SVG(VLM,参考模板 PNG) -4) **条件分支**: - - 未选配色:直接 `technical_ppt_generator` - - 选配色:`technical_route_colorize_svg` → `technical_ppt_generator` -5) `technical_ppt_generator`:若选配色则 PPT 插彩色,否则插黑白 - -关键点:黑白/上色两个节点都实现“手写 ReAct 重试”: -- 每次调用 agent 得到 SVG -- 做校验(XML + viewBox + CairoSVG 渲染) -- 
不通过则把错误写入 `state.temp_data["validation_feedback"]`,进入下一次尝试 - -黑白产物写入: -- `state.figure_tec_svg_bw_content` -- `state.svg_bw_file_path / state.svg_bw_img_path` -- 同时为了兼容旧字段:`state.svg_file_path / state.svg_img_path` 指向黑白 - -彩色产物写入: -- `state.figure_tec_svg_color_content` -- `state.svg_color_file_path / state.svg_color_img_path` - -### 3.4 Agent 与 Prompt 设计(尽量简化) - -#### 3.4.1 黑白 SVG 生成(参考模板 PNG) - -- Agent:`technical_route_bw_svg_generator` - - 文件:`dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py` - - 使用 VLM:`use_vlm=True`,`vlm_config.mode="understanding"`,`vlm_config.input_image=模板路径` - - 固定模型:`gpt-5.2` - - 输出:严格 JSON `{"svg_code":"..."}`,写入 `state.figure_tec_svg_bw_content` - -- Prompt(仅保留关键约束): - - 参考模板图结构/排版 - - 根据 `paper_idea` 生成内容 - - 黑白/灰度(限制 fill/stroke) - - `viewBox` 必须存在 - - 如有 `validation_feedback` 则修复 - -#### 3.4.2 彩色 SVG 上色(仅输入黑白SVG + 色卡) - -- Agent:`technical_route_colorize_svg` - - 文件:`dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py` - - 文本模式(非 VLM) - - 固定模型:`gpt-5.2` - - 输入:`bw_svg_code` + `palette_json` + `validation_feedback` - - 输出:严格 JSON `{"svg_code":"..."}`,写入 `state.figure_tec_svg_color_content` - -- Prompt(仅保留关键约束): - - 不改几何结构/坐标/path d/文字内容,只改 fill/stroke/style/class - - 同层级同色、同类型同色(引导使用 `level_colors`) - - 箭头统一 `arrow_color`,文字统一 `text_color` - - 如有 `validation_feedback` 则修复 - -### 3.5 前后端返回与兼容策略 - -后端响应模型:`fastapi_app/schemas.py` -- 兼容字段(始终存在/可能为空): - - `svg_filename/svg_image_filename`:**黑白**(即使选配色也指向黑白) -- 新增字段(选配色时返回): - - `svg_bw_filename/svg_bw_image_filename`:黑白(更明确) - - `svg_color_filename/svg_color_image_filename`:彩色 - -前端展示策略: -- 未选配色:只显示黑白下载与链接 -- 选配色:显示黑白 + 彩色两组下载与链接 -- 色卡选择:`SettingsCard` 中新增下拉与颜色圆点预览 - ---- - -## 4. 
开发计划(执行拆解) - -1) **模板与配色参数接入** - - request:`tech_route_template`、`tech_route_palette` - - 默认模板:`static/paper2any_imgs/p2t/temp.png` - -2) **黑白 SVG 节点** - - 新增 VLM agent + prompt - - workflow 内实现渲染校验重试(XML+viewBox+CairoSVG) - - 输出黑白 svg/png 并写回 state - -3) **可选上色节点** - - 新增上色 agent + prompt - - workflow 内实现渲染校验重试 - - 输出彩色 svg/png 并写回 state - -4) **后端双套字段返回** - - adapter 取出 state 中的 bw/color 路径并返回 - - service 层将路径转换为 outputs URL - -5) **前端:色卡下拉 + 两套下载/URL** - - 常量定义色卡列表 + 颜色预览 - - formData 透传 `tech_route_palette` + `tech_route_template` - - 页面展示黑白/彩色两套链接 - ---- - -## 5. 当前开发进度(已完成项) - -### 5.1 模板落位 -- 已将仓库根目录 `temp.png` 复制到:`static/paper2any_imgs/p2t/temp.png` - -### 5.2 后端接口与返回字段 -- `fastapi_app/schemas.py` - - `Paper2FigureRequest` 增加:`tech_route_template`, `tech_route_palette` - - `Paper2FigureResponse` 增加:`svg_bw_*`, `svg_color_*` -- `fastapi_app/routers/paper2any.py` - - `generate_paper2figure_json` 增加 Form 参数并传入 service -- `fastapi_app/services/paper2any_service.py` - - `generate_paper2figure_json` 透传 palette/template 到 request - - 响应中新增 4 个 URL 字段(bw/color) -- `fastapi_app/workflow_adapters/wa_paper2figure.py` - - 从 state 读取 `svg_bw_*`、`svg_color_*` 并填充到 response - -### 5.3 State 扩展 -- `dataflow_agent/state.py` - - `Paper2FigureRequest`:新增 `tech_route_template`/`tech_route_palette` - - `Paper2FigureState`:新增黑白/彩色 SVG/PNG 路径字段 - -### 5.4 Workflow 拆分与校验重试 -- `dataflow_agent/workflow/wf_paper2technical.py` - - 原单节点生成 SVG 改为: - - `technical_route_bw_svg_generator`(黑白、模板参考、循环重试渲染校验) - - `technical_route_colorize_svg`(可选、仅输入黑白SVG+色卡、循环重试渲染校验) - - PPT 选择插入:选配色→彩色,否则→黑白 - -### 5.5 Agent 与 Prompt -- 新增 agents: - - `dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py` - - `dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py` -- 新增 prompts(在现有模板文件中追加): - - `dataflow_agent/promptstemplates/resources/pt_technical_route_desc_generator_repo.py` - -### 5.6 前端交互与色卡预览 -- 
`frontend-workflow/src/components/paper2graph/constants.ts` - - 增加 `TECH_ROUTE_PALETTES`(4套色卡 + “不配色”) -- `frontend-workflow/src/components/paper2graph/index.tsx` - - localStorage 持久化 palette - - tech_route 请求时附带 `tech_route_palette` 与 `tech_route_template=temp.png` - - 解析并保存 `svg_bw_*`、`svg_color_*` -- `frontend-workflow/src/components/paper2graph/SettingsCard.tsx` - - 技术路线图时显示色卡下拉与颜色圆点预览 - - 选配色后显示黑白+彩色两套下载与链接 -- i18n: - - `frontend-workflow/src/locales/zh/paper2graph.json` 增加 `techRoute.paletteLabel` - - `frontend-workflow/src/locales/en/paper2graph.json` 增加 `techRoute.paletteLabel` - ---- - -## 6. 待办与风险点 - -### 6.1 待办 -- 跑一次最小联调/自测(建议): - - `TEXT` 模式不选配色:应只返回黑白 SVG/PNG - - `TEXT` 模式选配色:应返回黑白+彩色两套 SVG/PNG,且 PPT 插入彩色 - - `PDF` 模式同上 -- 前端可选:若你需要展示 PNG 预览链接(彩色/黑白),可在 SettingsCard 增加对应展示(当前只做了 SVG 下载/链接)。 - -### 6.2 风险点 -- VLM “understanding” 输出 SVG 的能力取决于模型服务对多模态的支持:若 `gpt-5.2` 在你的后端服务不支持 image-understanding,需要调整为支持多模态的 gpt-5.2 端点/代理。 -- 渲染校验依赖 CairoSVG 与字体环境:若部署环境缺字体,已在 workflow 里做了中文字体注入兜底,但仍可能存在缺字风险。 - ---- - -## 7. 
关键文件清单(便于代码审查) - -- 模板: - - `static/paper2any_imgs/p2t/temp.png` -- Workflow: - - `dataflow_agent/workflow/wf_paper2technical.py` -- Agents: - - `dataflow_agent/agentroles/paper2any_agents/technical_route_bw_svg_generator.py` - - `dataflow_agent/agentroles/paper2any_agents/technical_route_colorize_svg_agent.py` -- Prompts: - - `dataflow_agent/promptstemplates/resources/pt_technical_route_desc_generator_repo.py` -- Backend: - - `fastapi_app/schemas.py` - - `fastapi_app/routers/paper2any.py` - - `fastapi_app/services/paper2any_service.py` - - `fastapi_app/workflow_adapters/wa_paper2figure.py` -- Frontend: - - `frontend-workflow/src/components/paper2graph/constants.ts` - - `frontend-workflow/src/components/paper2graph/index.tsx` - - `frontend-workflow/src/components/paper2graph/SettingsCard.tsx` - - `frontend-workflow/src/locales/zh/paper2graph.json` - - `frontend-workflow/src/locales/en/paper2graph.json` - diff --git a/fastapi_app/.env.example b/fastapi_app/.env.example index 667fbcd..d7a9809 100644 --- a/fastapi_app/.env.example +++ b/fastapi_app/.env.example @@ -1,70 +1,29 @@ # ============================================ # Supabase Configuration (Optional) # ============================================ -# For JWT authentication and file management +# For JWT authentication and user data management SUPABASE_URL=https://your-project-id.supabase.co SUPABASE_ANON_KEY=your_supabase_anon_key +SUPABASE_SERVICE_ROLE_KEY=your_supabase_service_role_key # ============================================ -# Model Configuration - Layer 1: Base Models +# TTS Configuration # ============================================ -# Define all available model names. These are referenced by workflow and role configurations. 
-MODEL_GPT_4O=gpt-4o -MODEL_GPT_5_1=gpt-5.1 -MODEL_CLAUDE_HAIKU=claude-haiku-4-5-20251001 -MODEL_GEMINI_PRO_IMAGE=gemini-3-pro-image-preview -MODEL_GEMINI_FLASH_IMAGE=gemini-2.5-flash-image -MODEL_GEMINI_FLASH=gemini-2.5-flash -MODEL_QWEN_VL_OCR=qwen-vl-ocr-2025-11-20 +# Enable local TTS (0=disabled, 1=enabled) +USE_LOCAL_TTS=1 -# Default LLM API URL (internal service) -DEFAULT_LLM_API_URL=http://123.129.219.111:3000/v1/ +# TTS engine: qwen (default) or firered +TTS_ENGINE=qwen + +# Auto-unload timeout in seconds (default: 300 = 5 minutes) +TTS_IDLE_TIMEOUT=300 # ============================================ -# Model Configuration - Layer 2: Workflow-level Defaults +# LLM API Configuration # ============================================ -# Default models for each workflow. Override these to change workflow-wide defaults. - -# Paper2PPT Workflow -PAPER2PPT_DEFAULT_MODEL=gpt-5.1 -PAPER2PPT_DEFAULT_IMAGE_MODEL=gemini-3-pro-image-preview - -# PDF2PPT Workflow -PDF2PPT_DEFAULT_MODEL=gpt-4o -PDF2PPT_DEFAULT_IMAGE_MODEL=gemini-2.5-flash-image - -# Paper2Figure Workflow -PAPER2FIGURE_DEFAULT_MODEL=gpt-4o -PAPER2FIGURE_DEFAULT_IMAGE_MODEL=gemini-3-pro-image-preview - -# Paper2Video Workflow -PAPER2VIDEO_DEFAULT_MODEL=gpt-4o - -# Paper2Drawio Workflow -PAPER2DRAWIO_DEFAULT_MODEL=gpt-4o -PAPER2DRAWIO_VLM_MODEL=gemini-2.5-flash -PAPER2DRAWIO_ENABLE_VLM_VALIDATION=false - -# Knowledge Base -KB_EMBEDDING_MODEL=gemini-2.5-flash -KB_CHAT_MODEL=gpt-4o +DEFAULT_LLM_API_URL=http://your-llm-api-url/v1/ # ============================================ -# Model Configuration - Layer 3: Role-level (Fine-grained Control) +# Knowledge Base Configuration # ============================================ -# Override specific roles within workflows for maximum control. 
- -# Paper2PPT Roles -PAPER2PPT_OUTLINE_MODEL=gpt-5.1 # Outline generation -PAPER2PPT_CONTENT_MODEL=gpt-5.1 # Content generation -PAPER2PPT_IMAGE_GEN_MODEL=gemini-3-pro-image-preview # Image generation -PAPER2PPT_VLM_MODEL=qwen-vl-ocr-2025-11-20 # Vision-language model (OCR) -PAPER2PPT_CHART_MODEL=gpt-4o # Chart generation -PAPER2PPT_DESC_MODEL=gpt-5.1 # Figure description - -# Paper2Figure Roles -PAPER2FIGURE_TEXT_MODEL=gpt-4o -PAPER2FIGURE_IMAGE_MODEL=gemini-3-pro-image-preview -PAPER2FIGURE_VLM_MODEL=qwen-vl-ocr-2025-11-20 -PAPER2FIGURE_CHART_MODEL=gpt-4o -PAPER2FIGURE_DESC_MODEL=gpt-5.1 +KB_CHAT_MODEL=deepseek-v3.2 diff --git a/fastapi_app/config/settings.py b/fastapi_app/config/settings.py index f74a464..13e8c8a 100644 --- a/fastapi_app/config/settings.py +++ b/fastapi_app/config/settings.py @@ -1,12 +1,8 @@ """ -Application Settings - Three-tier Configuration System +Application Settings -This module provides a centralized configuration system with three layers: -1. Base Models: Fundamental model name definitions -2. Workflow-level: Default models for each workflow -3. Role-level: Fine-grained model assignments for specific roles - -All settings can be overridden via environment variables in .env file. +Model configurations are used as Pydantic defaults in schemas.py. +Frontend typically overrides these values, but they're kept for API compatibility. """ from pydantic_settings import BaseSettings @@ -14,73 +10,26 @@ class AppSettings(BaseSettings): - """ - Application configuration using three-tier architecture: - Base Models + Workflow-level + Role-level - - Environment variables can override any setting by using the same name. 
- Example: export PAPER2PPT_DEFAULT_MODEL=deepseek-v3.2 - """ - - # ============================================ - # Layer 1: Base Model Definitions - # ============================================ - # 默认 LLM(文本模型)统一为 deepseek-v3.2,可通过环境变量覆盖 - DEFAULT_LLM_MODEL: str = "deepseek-v3.2" - MODEL_GPT_4O: str = "deepseek-v3.2" - MODEL_GPT_5_1: str = "deepseek-v3.2" - MODEL_CLAUDE_HAIKU: str = "claude-haiku-4-5-20251001" - MODEL_GEMINI_PRO_IMAGE: str = "gemini-3-pro-image-preview" - MODEL_GEMINI_FLASH_IMAGE: str = "gemini-2.5-flash-image" - MODEL_GEMINI_FLASH: str = "gemini-2.5-flash" - MODEL_QWEN_VL_OCR: str = "qwen-vl-ocr-2025-11-20" + """Application configuration with environment variable support.""" # API Configuration DEFAULT_LLM_API_URL: str = "http://123.129.219.111:3000/v1/" - # ============================================ - # Layer 2: Workflow-level Default Models - # ============================================ - # Paper2PPT Workflow - PAPER2PPT_DEFAULT_MODEL: str = "deepseek-v3.2" - PAPER2PPT_DEFAULT_IMAGE_MODEL: str = "gemini-3-pro-image-preview" - - # PDF2PPT Workflow - PDF2PPT_DEFAULT_MODEL: str = "deepseek-v3.2" - PDF2PPT_DEFAULT_IMAGE_MODEL: str = "gemini-2.5-flash-image" - - # Paper2Figure Workflow - PAPER2FIGURE_DEFAULT_MODEL: str = "deepseek-v3.2" - PAPER2FIGURE_DEFAULT_IMAGE_MODEL: str = "gemini-3-pro-image-preview" - - # Paper2Video Workflow + # Model defaults (used in schemas.py, typically overridden by frontend) + MODEL_GPT_4O: str = "deepseek-v3.2" PAPER2VIDEO_DEFAULT_MODEL: str = "deepseek-v3.2" - # Paper2Drawio Workflow - PAPER2DRAWIO_DEFAULT_MODEL: str = "deepseek-v3.2" - PAPER2DRAWIO_VLM_MODEL: str = "deepseek-v3.2" - PAPER2DRAWIO_ENABLE_VLM_VALIDATION: bool = False - - # Knowledge Base - KB_EMBEDDING_MODEL: str = "gemini-2.5-flash" - KB_CHAT_MODEL: str = "deepseek-v3.2" - - # Fast Research (web search for 引入) - SERPER_API_KEY: Optional[str] = None - - # ============================================ - # Layer 3: Role-level Model Configuration - 
# ============================================ - # Paper2PPT role-specific models - PAPER2PPT_OUTLINE_MODEL: str = "deepseek-v3.2" # Outline generation - PAPER2PPT_CONTENT_MODEL: str = "deepseek-v3.2" # Content generation - PAPER2PPT_IMAGE_GEN_MODEL: str = "gemini-3-pro-image-preview" # Image generation - PAPER2PPT_VLM_MODEL: str = "qwen-vl-ocr-2025-11-20" # VLM vision understanding - PAPER2PPT_CHART_MODEL: str = "deepseek-v3.2" # Chart generation - PAPER2PPT_DESC_MODEL: str = "deepseek-v3.2" # Figure description - PAPER2PPT_TECHNICAL_MODEL: str = "deepseek-v3.2" # Technical details - - # Paper2Figure role-specific models + # Paper2PPT models + PAPER2PPT_DEFAULT_MODEL: str = "deepseek-v3.2" + PAPER2PPT_OUTLINE_MODEL: str = "deepseek-v3.2" + PAPER2PPT_CONTENT_MODEL: str = "deepseek-v3.2" + PAPER2PPT_IMAGE_GEN_MODEL: str = "gemini-3-pro-image-preview" + PAPER2PPT_VLM_MODEL: str = "qwen-vl-ocr-2025-11-20" + PAPER2PPT_CHART_MODEL: str = "deepseek-v3.2" + PAPER2PPT_DESC_MODEL: str = "deepseek-v3.2" + PAPER2PPT_TECHNICAL_MODEL: str = "deepseek-v3.2" + + # Paper2Figure models PAPER2FIGURE_TEXT_MODEL: str = "deepseek-v3.2" PAPER2FIGURE_IMAGE_MODEL: str = "gemini-3-pro-image-preview" PAPER2FIGURE_VLM_MODEL: str = "qwen-vl-ocr-2025-11-20" @@ -89,6 +38,22 @@ class AppSettings(BaseSettings): PAPER2FIGURE_REF_IMG_DESC_MODEL: str = "deepseek-v3.2" PAPER2FIGURE_TECHNICAL_MODEL: str = "deepseek-v3.2" + # Knowledge Base + KB_CHAT_MODEL: str = "deepseek-v3.2" + + # Search API + SERPER_API_KEY: Optional[str] = None + + # Supabase + SUPABASE_URL: Optional[str] = None + SUPABASE_ANON_KEY: Optional[str] = None + SUPABASE_SERVICE_ROLE_KEY: Optional[str] = None + + # TTS + USE_LOCAL_TTS: int = 0 + TTS_ENGINE: str = "qwen" + TTS_IDLE_TIMEOUT: int = 300 + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/fastapi_app/deep_research/__init__.py b/fastapi_app/deep_research/__init__.py new file mode 100644 index 0000000..86a1016 --- /dev/null +++ 
b/fastapi_app/deep_research/__init__.py @@ -0,0 +1,21 @@ +""" +DeepResearch 模块 +阿里巴巴通义实验室的深度研究代理系统 +完整集成到 Open-NotebookLM +""" + +from .react_agent import MultiTurnReactAgent +from .tool_search import Search +from .tool_visit import Visit +from .tool_python import PythonInterpreter +from .tool_scholar import Scholar +from .tool_file import FileParser + +__all__ = [ + 'MultiTurnReactAgent', + 'Search', + 'Visit', + 'PythonInterpreter', + 'Scholar', + 'FileParser', +] diff --git a/fastapi_app/deep_research/file_tools/__init__.py b/fastapi_app/deep_research/file_tools/__init__.py new file mode 100644 index 0000000..0268726 --- /dev/null +++ b/fastapi_app/deep_research/file_tools/__init__.py @@ -0,0 +1,29 @@ +""" +File tools for DeepResearch +""" +from .file_parser import SingleFileParser, compress +from .video_agent import VideoAgent +from .video_analysis import VideoAnalysis +from .idp import IDP +from .utils import ( + get_file_type, + hash_sha256, + is_http_url, + get_basename_from_url, + sanitize_chrome_file_path, + save_url_to_local_work_dir +) + +__all__ = [ + 'SingleFileParser', + 'compress', + 'VideoAgent', + 'VideoAnalysis', + 'IDP', + 'get_file_type', + 'hash_sha256', + 'is_http_url', + 'get_basename_from_url', + 'sanitize_chrome_file_path', + 'save_url_to_local_work_dir', +] diff --git a/fastapi_app/deep_research/file_tools/file_parser.py b/fastapi_app/deep_research/file_tools/file_parser.py new file mode 100644 index 0000000..4f603ae --- /dev/null +++ b/fastapi_app/deep_research/file_tools/file_parser.py @@ -0,0 +1,578 @@ +import json +import os +import re +import time +import zipfile +import math +from pathlib import Path + +from typing import Any, Dict, List, Optional, Union +from collections import Counter +import xml.etree.ElementTree as ET +from pandas import Timestamp +from datetime import datetime +from pandas.api.types import is_datetime64_any_dtype + +import pandas as pd +from tabulate import tabulate +from qwen_agent.log import logger +from 
qwen_agent.settings import DEFAULT_WORKSPACE, DEFAULT_MAX_INPUT_TOKENS +from qwen_agent.tools.base import BaseTool, register_tool +from qwen_agent.tools.storage import KeyNotExistsError, Storage +from .utils import (get_file_type, hash_sha256, is_http_url, get_basename_from_url, + sanitize_chrome_file_path, save_url_to_local_work_dir) +from qwen_agent.utils.tokenization_qwen import count_tokens, tokenizer +from .idp import IDP + +# Configuration constants +PARSER_SUPPORTED_FILE_TYPES = ['pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'doc', 'zip', '.mp4', '.mov', '.mkv', '.webm', '.mp3', '.wav'] +def str_to_bool(value): + """Convert string to boolean, handling common true/false representations""" + if isinstance(value, bool): + return value + return str(value).lower() in ('true', '1', 'yes', 'on') +USE_IDP = str_to_bool(os.getenv("USE_IDP", "True")) +IDP_TIMEOUT = 150000 +ENABLE_CSI = False +PARAGRAPH_SPLIT_SYMBOL = '\n' + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (datetime, Timestamp)): + return obj.isoformat() + return super().default(obj) + + +class FileParserError(Exception): + """Custom exception for document parsing errors""" + + def __init__(self, message: str, code: str = '400', exception: Optional[Exception] = None): + super().__init__(message) + self.code = code + self.exception = exception + + +def parse_file_by_idp(file_path: str = None, file_url: str = None) -> List[dict]: + idp = IDP() + try: + fid = idp.file_submit_with_url(file_url) if file_url else idp.file_submit_with_path(file_path) + if not fid: + return [] + + for _ in range(10): + result, status = idp.file_parser_query(fid) + if status == 'success': + return process_idp_result(result) + time.sleep(10) + + logger.error("IDP parsing timeout") + return [] + except Exception as e: + logger.error(f"IDP processing failed: {str(e)}") + return [] + + +def process_idp_result(result: dict) -> List[dict]: + pages = [] + current_page = 
None + + for layout in result.get('layouts', []): + page_num = layout.get('pageNum', 0) + content = layout.get('markdownContent', '') + + if current_page and current_page['page_num'] == page_num: + current_page['content'].append({'text': content}) + else: + current_page = {'page_num': page_num, 'content': [{'text': content}]} + pages.append(current_page) + + return pages + + +def clean_text(text: str) -> str: + cleaners = [ + lambda x: re.sub(r'\n+', '\n', x), + lambda x: x.replace("Add to Qwen's Reading List", ''), + lambda x: re.sub(r'-{6,}', '-----', x), + lambda x: x.strip() + ] + for cleaner in cleaners: + text = cleaner(text) + return text + + +def get_plain_doc(doc: list): + paras = [] + for page in doc: + for para in page['content']: + for k, v in para.items(): + if k in ['text', 'table', 'image']: + paras.append(v) + return PARAGRAPH_SPLIT_SYMBOL.join(paras) + + +def df_to_markdown(df: pd.DataFrame) -> str: + df = df.dropna(how='all').fillna('') + return tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + + +def parse_word(docx_path: str, extract_image: bool = False): + if extract_image: + raise ValueError('Currently, extracting images is not supported!') + + from docx import Document + doc = Document(docx_path) + + content = [] + for para in doc.paragraphs: + content.append({'text': para.text}) + for table in doc.tables: + tbl = [] + for row in table.rows: + tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|') + tbl = '\n'.join(tbl) + content.append({'table': tbl}) + return [{'page_num': 1, 'content': content}] + + +def parse_ppt(path: str, extract_image: bool = False): + if extract_image: + raise ValueError('Currently, extracting images is not supported!') + + from pptx import Presentation + from pptx.exc import PackageNotFoundError + try: + ppt = Presentation(path) + except PackageNotFoundError as ex: + logger.warning(ex) + return [] + doc = [] + for slide_number, slide in enumerate(ppt.slides): + page = {'page_num': 
slide_number + 1, 'content': []} + + for shape in slide.shapes: + if not shape.has_text_frame and not shape.has_table: + pass + + if shape.has_text_frame: + for paragraph in shape.text_frame.paragraphs: + paragraph_text = ''.join(run.text for run in paragraph.runs) + paragraph_text = clean_text(paragraph_text) + if paragraph_text.strip(): + page['content'].append({'text': paragraph_text}) + + if shape.has_table: + tbl = [] + for row_number, row in enumerate(shape.table.rows): + tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|') + tbl = '\n'.join(tbl) + page['content'].append({'table': tbl}) + doc.append(page) + return doc + + +def parse_pdf(pdf_path: str, extract_image: bool = False) -> List[dict]: + # Todo: header and footer + from pdfminer.high_level import extract_pages + from pdfminer.layout import LTImage, LTRect, LTTextContainer + + doc = [] + import pdfplumber + pdf = pdfplumber.open(pdf_path) + for i, page_layout in enumerate(extract_pages(pdf_path)): + page = {'page_num': page_layout.pageid, 'content': []} + + elements = [] + for element in page_layout: + elements.append(element) + + # Init params for table + table_num = 0 + tables = [] + + for element in elements: + if isinstance(element, LTRect): + if not tables: + tables = extract_tables(pdf, i) + if table_num < len(tables): + table_string = table_converter(tables[table_num]) + table_num += 1 + if table_string: + page['content'].append({'table': table_string, 'obj': element}) + elif isinstance(element, LTTextContainer): + # Delete line breaks in the same paragraph + text = element.get_text() + # Todo: Further analysis using font + font = get_font(element) + if text.strip(): + new_content_item = {'text': text, 'obj': element} + if font: + new_content_item['font-size'] = round(font[1]) + # new_content_item['font-name'] = font[0] + page['content'].append(new_content_item) + elif extract_image and isinstance(element, LTImage): + # Todo: ocr + raise ValueError('Currently, extracting images 
is not supported!') + else: + pass + + # merge elements + page['content'] = postprocess_page_content(page['content']) + doc.append(page) + + return doc + + +def parse_txt(path: str): + with open(path, 'r', encoding='utf-8') as f: + text = f.read() + paras = text.split(PARAGRAPH_SPLIT_SYMBOL) + content = [] + for p in paras: + content.append({'text': p}) + return [{'page_num': 1, 'content': content}] + + +def get_font(element): + from pdfminer.layout import LTChar, LTTextContainer + + fonts_list = [] + for text_line in element: + if isinstance(text_line, LTTextContainer): + for character in text_line: + if isinstance(character, LTChar): + fonts_list.append((character.fontname, character.size)) + + fonts_list = list(set(fonts_list)) + if fonts_list: + counter = Counter(fonts_list) + most_common_fonts = counter.most_common(1)[0][0] + return most_common_fonts + else: + return [] + + +def extract_tables(pdf, page_num): + table_page = pdf.pages[page_num] + tables = table_page.extract_tables() + return tables + + +def table_converter(table): + table_string = '' + for row_num in range(len(table)): + row = table[row_num] + cleaned_row = [ + item.replace('\n', ' ') if item is not None and '\n' in item else 'None' if item is None else item + for item in row + ] + table_string += ('|' + '|'.join(cleaned_row) + '|' + '\n') + table_string = table_string[:-1] + return table_string + + +def postprocess_page_content(page_content: list) -> list: + # rm repetitive identification for table and text + # Some documents may repeatedly recognize LTRect and LTTextContainer + table_obj = [p['obj'] for p in page_content if 'table' in p] + tmp = [] + for p in page_content: + repetitive = False + if 'text' in p: + for t in table_obj: + if t.bbox[0] <= p['obj'].bbox[0] and p['obj'].bbox[1] <= t.bbox[1] and t.bbox[2] <= p['obj'].bbox[ + 2] and p['obj'].bbox[3] <= t.bbox[3]: + repetitive = True + break + + if not repetitive: + tmp.append(p) + page_content = tmp + + # merge paragraphs that have 
been separated by mistake + new_page_content = [] + for p in page_content: + if new_page_content and 'text' in new_page_content[-1] and 'text' in p and abs( + p.get('font-size', 12) - + new_page_content[-1].get('font-size', 12)) < 2 and p['obj'].height < p.get('font-size', 12) + 1: + # Merge those lines belonging to a paragraph + new_page_content[-1]['text'] += f' {p["text"]}' + # new_page_content[-1]['font-name'] = p.get('font-name', '') + new_page_content[-1]['font-size'] = p.get('font-size', 12) + else: + p.pop('obj') + new_page_content.append(p) + for i in range(len(new_page_content)): + if 'text' in new_page_content[i]: + new_page_content[i]['text'] = clean_text(new_page_content[i]['text']) + return new_page_content + + +def extract_xls_schema(file_path: str) -> Dict[str, Any]: + xls = pd.ExcelFile(file_path) + schema = { + "sheets": [], + "n_sheets": len(xls.sheet_names) + } + + for sheet_name in xls.sheet_names: + df = xls.parse(sheet_name, nrows=3) # 读取前3行 + + dtype_mapping = { + 'object': 'string', + 'datetime64[ns]': 'datetime', + 'timedelta64[ns]': 'timedelta' + } + dtypes = df.dtypes.astype(str).replace(dtype_mapping).to_dict() + + sample_df = df.head(3).copy() + for col in sample_df.columns: + if is_datetime64_any_dtype(sample_df[col]): + sample_df[col] = sample_df[col].dt.strftime('%Y-%m-%dT%H:%M:%S') + + sheet_info = { + "name": sheet_name, + "columns": df.columns.tolist(), + "dtypes": dtypes, + "sample_data": sample_df.to_dict(orient='list') + } + schema["sheets"].append(sheet_info) + + return schema + + +def extract_csv_schema(file_path: str) -> Dict[str, Any]: + df_dtype = pd.read_csv(file_path, nrows=100) + df_sample = pd.read_csv(file_path, nrows=3) + + return { + "columns": df_dtype.columns.tolist(), + "dtypes": df_dtype.dtypes.astype(str).to_dict(), + "sample_data": df_sample.to_dict(orient='list'), + "estimated_total_rows": _estimate_total_rows(file_path) + } + + +def _estimate_total_rows(file_path) -> int: + with open(file_path, 'rb') as f: 
+ line_count = 0 + chunk_size = 1024 * 1024 + while chunk := f.read(chunk_size): + line_count += chunk.count(b'\n') + return line_count - 1 + + +def parse_tabular_file(file_path: str, **kwargs) -> List[dict]: + try: + df = pd.read_excel(file_path) if file_path.endswith(('.xlsx', '.xls')) else \ + pd.read_csv(file_path, **kwargs) + if count_tokens(df_to_markdown(df)) > DEFAULT_MAX_INPUT_TOKENS: + schema = extract_xls_schema(file_path) if file_path.endswith(('.xlsx', '.xls')) else \ + extract_csv_schema(file_path) + return [{'page_num': 1, 'content': [{'schema': schema}]}] + else: + return [{'page_num': 1, 'content': [{'table': df_to_markdown(df)}]}] + except Exception as e: + logger.error(f"Table parsing failed: {str(e)}") + return [] + + +def parse_zip(file_path: str, extract_dir: str) -> List[dict]: + with zipfile.ZipFile(file_path, 'r') as zip_ref: + zip_ref.extractall(extract_dir) + return [os.path.join(extract_dir, f) for f in zip_ref.namelist()] + + +def parse_html(file_path: str) -> List[dict]: + from bs4 import BeautifulSoup + + with open(file_path, 'r', encoding='utf-8') as f: + soup = BeautifulSoup(f, 'lxml') + + content = [{'text': clean_text(p.get_text())} + for p in soup.find_all(['p', 'div']) if p.get_text().strip()] + + return [{ + 'page_num': 1, + 'content': content, + 'title': soup.title.string if soup.title else '' + }] + + +def extract_xml_skeleton_markdown(xml_file): + tree = ET.parse(xml_file) + root = tree.getroot() + markdown_lines = [] + + def process_element(element, level=0, parent_path="", is_last=True, prefix=""): + if level > 0: + connector = "└── " if is_last else "├── " + markdown_lines.append(f"{prefix}{connector}**{element.tag}**") + else: + markdown_lines.append(f"## Root: {element.tag}") + + if element.attrib: + attrs = [f"`{k}`" for k in element.attrib.keys()] + attr_line = f"{prefix}{' ' if level > 0 else ''}*Attributes:* {', '.join(attrs)}" + markdown_lines.append(attr_line) + + if element.text and element.text.strip(): + 
text_line = f"{prefix}{' ' if level > 0 else ''}*Has text content*" + markdown_lines.append(text_line) + seen_tags = set() + unique_children = [] + for child in element: + if child.tag not in seen_tags: + seen_tags.add(child.tag) + unique_children.append(child) + + for i, child in enumerate(unique_children): + is_last_child = (i == len(unique_children) - 1) + child_prefix = prefix + (" " if is_last else "│ ") + process_element(child, level + 1, + f"{parent_path}/{element.tag}" if parent_path else element.tag, + is_last_child, child_prefix) + + process_element(root) + markdown_content = "\n".join(markdown_lines) + return markdown_content + + +def parse_xml(file_path: str) -> List[dict]: + with open(file_path, 'r', encoding='utf-8') as f: + text = f.read() + if count_tokens(text) > DEFAULT_MAX_INPUT_TOKENS: + schema = extract_xml_skeleton_markdown(file_path) + content = [{'schema': schema}] + else: + content = [{'text': text}] + return [{'page_num': 1, 'content': content}] + + +def compress(results: list) -> list[str]: + compress_results = [] + max_token = math.floor(DEFAULT_MAX_INPUT_TOKENS / len(results)) + for result in results: + token_list = tokenizer.tokenize(result) + token_list = token_list[:min(len(token_list), max_token)] + compress_results.append(tokenizer.convert_tokens_to_string(token_list)) + return compress_results + + +# @register_tool('file_parser') +class SingleFileParser(BaseTool): + name="file_parser" + description = f"File parsing tool, supports parsing data in {'/'.join(PARSER_SUPPORTED_FILE_TYPES)} formats, and returns the parsed markdown format data." 
+ parameters = [{ + 'name': 'url', + 'type': 'string', + 'description': 'The full path of the file to be parsed, which can be a local path or a downloadable http(s) link.', + 'required': True + }] + + def __init__(self, cfg: Optional[Dict] = None): + super().__init__(cfg) + self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name)) + self.db = Storage({'storage_root_path': self.data_root}) + self.structured_doc = self.cfg.get('structured_doc', True) + + + self.parsers = { + 'pdf': parse_pdf, + 'docx': parse_word, + 'doc': parse_word, + 'pptx': parse_ppt, + 'txt': parse_txt, + 'jsonl': parse_txt, + 'jsonld': parse_txt, + 'pdb': parse_txt, + 'py': parse_txt, + 'html': parse_html, + 'xml': parse_xml, + 'csv': lambda p: parse_tabular_file(p, sep=','), + 'tsv': lambda p: parse_tabular_file(p, sep='\t'), + 'xlsx': parse_tabular_file, + 'xls': parse_tabular_file, + 'zip': self.parse_zip + } + + def call(self, params: Union[str, dict], **kwargs) -> Union[str, list]: + params = self._verify_json_format_args(params) + file_path = self._prepare_file(params['url']) + try: + cached = self.db.get(f'{hash_sha256(file_path)}_ori') + return self._flatten_result(json.loads(cached)) + except KeyNotExistsError: + return self._flatten_result(self._process_new_file(file_path)) + + def _prepare_file(self, path: str) -> str: + if is_http_url(path): + download_dir = os.path.join(self.data_root, hash_sha256(path)) + os.makedirs(download_dir, exist_ok=True) + return save_url_to_local_work_dir(path, download_dir) + return sanitize_chrome_file_path(path) + + def _process_new_file(self, file_path: str) -> Union[str, list]: + file_type = get_file_type(file_path) + idp_types = ['pdf', 'docx', 'pptx', 'xlsx', 'jpg', 'png', 'mp3'] + logger.info(f'Start parsing {file_path}...') + logger.info(f'File type {file_type}...') + logger.info(f"structured_doc {self.cfg.get('structured_doc')}...") + + if file_type not in idp_types: + file_type = 
get_basename_from_url(file_path).split('.')[-1].lower() + + try: + if USE_IDP and file_type in idp_types: + try: + results = parse_file_by_idp(file_path=file_path) + except Exception as e: + results = self.parsers[file_type](file_path) + else: + results = self.parsers[file_type](file_path) + tokens = 0 + for page in results: + for para in page['content']: + if 'schema' in para: + para['token'] = count_tokens(json.dumps(para['schema'])) + else: + para['token'] = count_tokens(para.get('text', para.get('table'))) + tokens += para['token'] + + if not results or not tokens: + logger.error(f"Parsing failed: No information was parsed") + raise FileParserError("Document parsing failed") + else: + self._cache_result(file_path, results) + return results + except Exception as e: + logger.error(f"Parsing failed: {str(e)}") + raise FileParserError("Document parsing failed", exception=e) + + def _cache_result(self, file_path: str, result: list): + cache_key = f'{hash_sha256(file_path)}_ori' + self.db.put(cache_key, json.dumps(result, ensure_ascii=False)) + logger.info(f'The parsing result of {file_path} has been cached') + + def _flatten_result(self, result: list) -> str: + return PARAGRAPH_SPLIT_SYMBOL.join( + para.get('text', para.get('table', '')) + for page in result for para in page['content'] + ) + + def parse_zip(self, file_path: str) -> List[dict]: + extract_dir = os.path.join(self.data_root, f"zip_{hash_sha256(file_path)}") + os.makedirs(extract_dir, exist_ok=True) + + results = [] + for extracted_file in parse_zip(file_path, extract_dir): + if (ft := get_file_type(extracted_file)) in self.parsers: + try: + results.extend(self.parsers[ft](extracted_file)) + except Exception as e: + logger.warning(f"Skip files {extracted_file}: {str(e)}") + + if not results: + raise ValueError("No parseable content found in the ZIP file") + return results diff --git a/fastapi_app/deep_research/file_tools/idp.py b/fastapi_app/deep_research/file_tools/idp.py new file mode 100644 index 
0000000..71199cb --- /dev/null +++ b/fastapi_app/deep_research/file_tools/idp.py @@ -0,0 +1,90 @@ +import os +import json + +from alibabacloud_docmind_api20220711.client import Client as docmind_api20220711Client +from alibabacloud_tea_openapi import models as open_api_models +from alibabacloud_docmind_api20220711 import models as docmind_api20220711_models +from alibabacloud_tea_util.client import Client as UtilClient +from alibabacloud_tea_util import models as util_models +from alibabacloud_credentials.client import Client as CredClient + +key = os.environ.get('IDP_KEY_ID') +secret = os.environ.get('IDP_KEY_SECRET') + + +class IDP(): + def __init__(self): + config = open_api_models.Config( + access_key_id=key, + access_key_secret=secret + ) + config.endpoint = f'docmind-api.cn-hangzhou.aliyuncs.com' + self.client = docmind_api20220711Client(config) + + def file_submit_with_url(self, file_url): + print('parsing with document url ', file_url) + file_name = os.path.basename(file_url) + request = docmind_api20220711_models.SubmitDocParserJobAdvanceRequest( + file_url=file_url, + file_name=file_name, + reveal_markdown=True, + ) + runtime = util_models.RuntimeOptions() + result_dict = None + try: + response = self.client.submit_doc_parser_job_advance(request,runtime) + result_dict = response.body.data.id + except Exception as error: + UtilClient.assert_as_string(error.message) + + return result_dict + + + def file_submit_with_path(self, file_path): + print('parsing with document local path ', file_path) + file_name = os.path.basename(file_path) + request = docmind_api20220711_models.SubmitDocParserJobAdvanceRequest( + file_url_object=open(file_path, "rb"), + file_name=file_name, + ) + runtime = util_models.RuntimeOptions() + result_dict = None + try: + response = self.client.submit_doc_parser_job_advance(request, runtime) + result_dict = response.body.data.id + except Exception as error: + print(error) + UtilClient.assert_as_string(error.message) + + return 
result_dict + + def file_parser_query(self,fid): + request = docmind_api20220711_models.QueryDocParserStatusRequest( + id=fid + ) + try: + response = self.client.query_doc_parser_status(request) + NumberOfSuccessfulParsing = response.body.data + except Exception as e: + print(e) + return None + status_parse = response.body.data.status + NumberOfSuccessfulParsing = NumberOfSuccessfulParsing.__dict__ + responses = dict() + for i in range(0, NumberOfSuccessfulParsing["number_of_successful_parsing"], 3000): + request = docmind_api20220711_models.GetDocParserResultRequest( + id=fid, + layout_step_size=3000, + layout_num=i + ) + try: + response = self.client.get_doc_parser_result(request) + result = response.body.data + if not responses: + responses = result + else: + responses['layouts'].extend(result['layouts']) + except Exception as error: + return None,status_parse + return responses,status_parse + \ No newline at end of file diff --git a/fastapi_app/deep_research/file_tools/utils.py b/fastapi_app/deep_research/file_tools/utils.py new file mode 100644 index 0000000..6888b93 --- /dev/null +++ b/fastapi_app/deep_research/file_tools/utils.py @@ -0,0 +1,542 @@ +import base64 +import copy +import hashlib +import json +import os +import re +import shutil +import signal +import socket +import sys +import time +import traceback +import urllib.parse +from io import BytesIO +from typing import Any, List, Literal, Optional, Tuple, Union + +import json5 +import requests +from pydantic import BaseModel + +from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, FUNCTION, SYSTEM, USER, ContentItem, Message +from qwen_agent.log import logger + + +def append_signal_handler(sig, handler): + """ + Installs a new signal handler while preserving any existing handler. + If an existing handler is present, it will be called _after_ the new handler. 
+ """ + + old_handler = signal.getsignal(sig) + if not callable(old_handler): + old_handler = None + if sig == signal.SIGINT: + + def old_handler(*args, **kwargs): + raise KeyboardInterrupt + elif sig == signal.SIGTERM: + + def old_handler(*args, **kwargs): + raise SystemExit + + def new_handler(*args, **kwargs): + handler(*args, **kwargs) + if old_handler is not None: + old_handler(*args, **kwargs) + + signal.signal(sig, new_handler) + + +def get_local_ip() -> str: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + # doesn't even have to be reachable + s.connect(('10.255.255.255', 1)) + ip = s.getsockname()[0] + except Exception: + ip = '127.0.0.1' + finally: + s.close() + return ip + + +def hash_sha256(text: str) -> str: + hash_object = hashlib.sha256(text.encode()) + key = hash_object.hexdigest() + return key + + +def print_traceback(is_error: bool = True): + tb = ''.join(traceback.format_exception(*sys.exc_info(), limit=3)) + if is_error: + logger.error(tb) + else: + logger.warning(tb) + + +CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]') + + +def has_chinese_chars(data: Any) -> bool: + text = f'{data}' + return bool(CHINESE_CHAR_RE.search(text)) + + +def has_chinese_messages(messages: List[Union[Message, dict]], check_roles: Tuple[str] = (SYSTEM, USER)) -> bool: + for m in messages: + if m['role'] in check_roles: + if has_chinese_chars(m['content']): + return True + return False + + +def get_basename_from_url(path_or_url: str, need_rm_uuid: bool = False) -> str: + if re.match(r'^[A-Za-z]:\\', path_or_url): + # "C:\\a\\b\\c" -> "C:/a/b/c" + path_or_url = path_or_url.replace('\\', '/') + + # "/mnt/a/b/c" -> "c" + # "https://github.com/here?k=v" -> "here" + # "https://github.com/" -> "" + basename = urllib.parse.urlparse(path_or_url).path + basename = os.path.basename(basename) + basename = urllib.parse.unquote(basename) + basename = basename.strip() + + # "https://github.com/" -> "" -> "github.com" + if not basename: + basename = [x.strip() for x in 
path_or_url.split('/') if x.strip()][-1] + + new_basename = basename + if need_rm_uuid: + try: + # Hotfix: rm uuid + if len(basename) > 38 and basename[8] == '-' and basename[13] == '-' and basename[18] == '-' and basename[ + 23] == '-' and basename[36] == '_': + new_basename = basename[37:] + except Exception: + new_basename = basename + return new_basename + + +def is_http_url(path_or_url: str) -> bool: + if path_or_url.startswith('https://') or path_or_url.startswith('http://'): + return True + return False + + +def is_image(path_or_url: str) -> bool: + filename = get_basename_from_url(path_or_url).lower() + for ext in ['jpg', 'jpeg', 'png', 'webp']: + if filename.endswith(ext): + return True + return False + + +def sanitize_chrome_file_path(file_path: str) -> str: + if os.path.exists(file_path): + return file_path + + # Dealing with "file:///...": + new_path = urllib.parse.urlparse(file_path) + new_path = urllib.parse.unquote(new_path.path) + new_path = sanitize_windows_file_path(new_path) + if os.path.exists(new_path): + return new_path + + return sanitize_windows_file_path(file_path) + + +def sanitize_windows_file_path(file_path: str) -> str: + # For Linux and macOS. + if os.path.exists(file_path): + return file_path + + # For native Windows, drop the leading '/' in '/C:/' + win_path = file_path + if win_path.startswith('/'): + win_path = win_path[1:] + if os.path.exists(win_path): + return win_path + + # For Windows + WSL. + if re.match(r'^[A-Za-z]:/', win_path): + wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}' + if os.path.exists(wsl_path): + return wsl_path + + # For native Windows, replace / with \. 
+ win_path = win_path.replace('/', '\\') + if os.path.exists(win_path): + return win_path + + return file_path + + +def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: str = '') -> str: + if not save_filename: + save_filename = get_basename_from_url(url) + new_path = os.path.join(save_dir, save_filename) + if os.path.exists(new_path): + os.remove(new_path) + # logger.info(f'Downloading {url} to {new_path}...') + start_time = time.time() + if not is_http_url(url): + url = sanitize_chrome_file_path(url) + shutil.copy(url, new_path) + else: + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + response = requests.get(url, headers=headers) + if response.status_code == 200: + with open(new_path, 'wb') as file: + file.write(response.content) + else: + raise ValueError('Can not download this file. Please check your network or the file link.') + end_time = time.time() + # logger.info(f'Finished downloading {url} to {new_path}. Time spent: {end_time - start_time} seconds.') + return new_path + + +def save_text_to_file(path: str, text: str) -> None: + with open(path, 'w', encoding='utf-8') as fp: + fp.write(text) + + +def read_text_from_file(path: str) -> str: + try: + with open(path, 'r', encoding='utf-8') as file: + file_content = file.read() + except UnicodeDecodeError: + print_traceback(is_error=False) + from charset_normalizer import from_path + results = from_path(path) + file_content = str(results.best()) + return file_content + + +def contains_html_tags(text: str) -> bool: + pattern = r'<(p|span|div|li|html|script)[^>]*?' 
+ return bool(re.search(pattern, text)) + + +def get_content_type_by_head_request(path: str) -> str: + try: + response = requests.head(path, timeout=5) + content_type = response.headers.get('Content-Type', '') + return content_type + except requests.RequestException: + return 'unk' + + +def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls','zip','mp3','jsonl','pdb','py','xml']: + f_type = get_basename_from_url(path).split('.')[-1].lower() + if is_image(path): + return "image" + if f_type in ['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls','zip','mp3','jsonl','pdb','py','xml']: + # Specially supported file types + return f_type + + if is_http_url(path): + # The HTTP header information for the response is obtained by making a HEAD request to the target URL, + # where the Content-type field usually indicates the Type of Content to be returned + content_type = get_content_type_by_head_request(path) + if 'application/pdf' in content_type: + return 'pdf' + elif 'application/msword' in content_type: + return 'docx' + + # Assuming that the URL is HTML by default, + # because the file downloaded by the request may contain html tags + return 'html' + else: + # Determine by reading local HTML file + try: + content = read_text_from_file(path) + except Exception: + print_traceback() + return 'unk' + + if contains_html_tags(content): + return 'html' + else: + return 'txt' + + +def extract_urls(text: str) -> List[str]: + pattern = re.compile(r'https?://\S+') + urls = re.findall(pattern, text) + return urls + + +def extract_markdown_urls(md_text: str) -> List[str]: + pattern = r'!?\[[^\]]*\]\(([^\)]+)\)' + urls = re.findall(pattern, md_text) + return urls + + +def extract_code(text: str) -> str: + # Match triple backtick blocks first + triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) + if triple_match: + text = triple_match.group(1) + else: + try: + text = json5.loads(text)['code'] + except Exception: + 
print_traceback(is_error=False) + # If no code blocks found, return original text + return text + + +def json_loads(text: str) -> dict: + text = text.strip('\n') + if text.startswith('```') and text.endswith('\n```'): + text = '\n'.join(text.split('\n')[1:-1]) + try: + return json.loads(text) + except json.decoder.JSONDecodeError as json_err: + try: + return json5.loads(text) + except ValueError: + raise json_err + + +class PydanticJSONEncoder(json.JSONEncoder): + + def default(self, obj): + if isinstance(obj, BaseModel): + return obj.model_dump() + return super().default(obj) + + +def json_dumps_pretty(obj: dict, ensure_ascii=False, indent=2, **kwargs) -> str: + return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, cls=PydanticJSONEncoder, **kwargs) + + +def json_dumps_compact(obj: dict, ensure_ascii=False, indent=None, **kwargs) -> str: + return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, cls=PydanticJSONEncoder, **kwargs) + + +def format_as_multimodal_message( + msg: Message, + add_upload_info: bool, + add_multimodel_upload_info: bool, + lang: Literal['auto', 'en', 'zh'] = 'auto', +) -> Message: + assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) + content: List[ContentItem] = [] + if isinstance(msg.content, str): # if text content + if msg.content: + content = [ContentItem(text=msg.content)] + elif isinstance(msg.content, list): # if multimodal content + files = [] + for item in msg.content: + k, v = item.get_type_and_value() + if k in ('text', 'image', 'audio', 'video'): + content.append(item) + if k == 'file': + # Move 'file' out of 'content' since it's not natively supported by models + files.append((v, k)) + if add_multimodel_upload_info and k in ('image', 'video'): + # Indicate the image name + # TODO: considering audio + files.append((v, k)) + if add_upload_info and files and (msg.role in (SYSTEM, USER)): + if lang == 'auto': + has_zh = has_chinese_chars(msg) + else: + has_zh = (lang == 'zh') + upload = [] + # for f, k in 
[(get_basename_from_url(f, need_rm_uuid=True), k) for f, k in files]: + for f, k in [(f, k) for f, k in files]: + if k == 'image': + if has_zh: + upload.append(f'![图片]({f})') + else: + upload.append(f'![image]({f})') + elif k == 'video': + if has_zh: + upload.append(f'![视频]({f})') + else: + upload.append(f'![video]({f})') + else: + if has_zh: + upload.append(f'[文件]({f})') + else: + upload.append(f'[file]({f})') + upload = ' '.join(upload) + if has_zh: + upload = f'(上传了 {upload})\n\n' + else: + upload = f'(Uploaded {upload})\n\n' + + # Check and avoid adding duplicate upload info + upload_info_already_added = False + for item in content: + if item.text and (upload in item.text): + upload_info_already_added = True + + if not upload_info_already_added: + content = [ContentItem(text=upload)] + content + else: + raise TypeError + msg = Message(role=msg.role, + content=content, + name=msg.name if msg.role == FUNCTION else None, + function_call=msg.function_call, + extra=msg.extra) + return msg + + +def format_as_text_message( + msg: Message, + add_upload_info: bool, + lang: Literal['auto', 'en', 'zh'] = 'auto', +) -> Message: + msg = format_as_multimodal_message(msg, + add_upload_info=add_upload_info, + add_multimodel_upload_info=add_upload_info, + lang=lang) + text = '' + for item in msg.content: + if item.type == 'text': + text += item.value + msg.content = text + return msg + + +def extract_text_from_message( + msg: Message, + add_upload_info: bool, + lang: Literal['auto', 'en', 'zh'] = 'auto', +) -> str: + if isinstance(msg.content, list): + text = format_as_text_message(msg, add_upload_info=add_upload_info, lang=lang).content + elif isinstance(msg.content, str): + text = msg.content + else: + raise TypeError(f'List of str or str expected, but received {type(msg.content).__name__}.') + return text.strip() + + +def extract_files_from_messages(messages: List[Message], include_images: bool) -> List[str]: + files = [] + for msg in messages: + if isinstance(msg.content, 
list): + for item in msg.content: + if item.file and item.file not in files: + files.append(item.file) + if include_images and item.image and item.image not in files: + files.append(item.image) + return files + + +def merge_generate_cfgs(base_generate_cfg: Optional[dict], new_generate_cfg: Optional[dict]) -> dict: + generate_cfg: dict = copy.deepcopy(base_generate_cfg or {}) + if new_generate_cfg: + for k, v in new_generate_cfg.items(): + if k == 'stop': + stop = generate_cfg.get('stop', []) + stop = stop + [s for s in v if s not in stop] + generate_cfg['stop'] = stop + else: + generate_cfg[k] = v + return generate_cfg + + +def build_text_completion_prompt( + messages: List[Message], + allow_special: bool = False, + default_system: str = DEFAULT_SYSTEM_MESSAGE, +) -> str: + im_start = '<|im_start|>' + im_end = '<|im_end|>' + + if messages[0].role == SYSTEM: + sys = messages[0].content + assert isinstance(sys, str) + prompt = f'{im_start}{SYSTEM}\n{sys}{im_end}' + messages = messages[1:] + else: + prompt = f'{im_start}{SYSTEM}\n{default_system}{im_end}' + + # Make sure we are completing the chat in the tone of the assistant + if messages[-1].role != ASSISTANT: + messages = messages + [Message(ASSISTANT, '')] + + for msg in messages: + assert isinstance(msg.content, str) + content = msg.content + if allow_special: + assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) + if msg.function_call: + assert msg.role == ASSISTANT + tool_call = msg.function_call.arguments + try: + tool_call = {'name': msg.function_call.name, 'arguments': json.loads(tool_call)} + tool_call = json.dumps(tool_call, ensure_ascii=False, indent=2) + except json.decoder.JSONDecodeError: + tool_call = '{"name": "' + msg.function_call.name + '", "arguments": ' + tool_call + '}' + if content: + content += '\n' + content += f'\n{tool_call}\n' + else: + assert msg.role in (USER, ASSISTANT) + assert msg.function_call is None + prompt += f'\n{im_start}{msg.role}\n{content}{im_end}' + + assert 
prompt.endswith(im_end) + prompt = prompt[:-len(im_end)] + return prompt + + +def encode_image_as_base64(path: str, max_short_side_length: int = -1) -> str: + from PIL import Image + image = Image.open(path) + + if (max_short_side_length > 0) and (min(image.size) > max_short_side_length): + ori_size = image.size + image = resize_image(image, short_side_length=max_short_side_length) + logger.debug(f'Image "{path}" resized from {ori_size} to {image.size}.') + + image = image.convert(mode='RGB') + buffered = BytesIO() + image.save(buffered, format='JPEG') + return 'data:image/jpeg;base64,' + base64.b64encode(buffered.getvalue()).decode('utf-8') + + +def load_image_from_base64(image_base64: Union[bytes, str]): + from PIL import Image + image = Image.open(BytesIO(base64.b64decode(image_base64))) + image.load() + return image + + +def resize_image(img, short_side_length: int = 1080): + from PIL import Image + assert isinstance(img, Image.Image) + + width, height = img.size + + if width <= height: + new_width = short_side_length + new_height = int((short_side_length / width) * height) + else: + new_height = short_side_length + new_width = int((short_side_length / height) * width) + + resized_img = img.resize((new_width, new_height), resample=Image.Resampling.BILINEAR) + return resized_img + + +def get_last_usr_msg_idx(messages: List[Union[dict, Message]]) -> int: + i = len(messages) - 1 + while (i >= 0) and (messages[i]['role'] != 'user'): + i -= 1 + assert i >= 0, messages + assert messages[i]['role'] == 'user' + return i diff --git a/fastapi_app/deep_research/file_tools/video_agent.py b/fastapi_app/deep_research/file_tools/video_agent.py new file mode 100644 index 0000000..5020a73 --- /dev/null +++ b/fastapi_app/deep_research/file_tools/video_agent.py @@ -0,0 +1,92 @@ +""" +input: + - query/goal: str + - Docs: List[file]/List[url] + - file type: 'pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'doc', 'zip', '.mp4', '.mov', '.avi', '.mkv', '.webm', 
'.mp3', '.wav', '.aac', '.ogg', '.flac' +output: + - answer: str + - useful_information: str +""" +import sys +import os +import re +import copy +import json +from typing import Dict, Iterator, List, Literal, Tuple, Union, Any, Optional +import json5 +import asyncio +from openai import OpenAI + +from qwen_agent.tools.base import BaseTool, register_tool +from qwen_agent.agents import Assistant +from qwen_agent.llm import BaseChatModel +from qwen_agent.llm.schema import ASSISTANT, USER, FUNCTION, Message, DEFAULT_SYSTEM_MESSAGE, SYSTEM, ROLE +from qwen_agent.tools import BaseTool +from qwen_agent.log import logger +from qwen_agent.utils.tokenization_qwen import count_tokens, tokenizer +from qwen_agent.settings import DEFAULT_WORKSPACE, DEFAULT_MAX_INPUT_TOKENS + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(current_dir)) + +from .video_analysis import VideoAnalysis + + +async def video_analysis(params, **kwargs): + """Modified video_analysis to handle multiple URLs""" + print(params) + files = params.get('files', []) + prompt = params.get('prompt', '') + + # Ensure URLs are in a list + + # Process each URL + results = [] + for file in files: + try: + # Create parameters for each URL + single_url_params = json.dumps({ + 'url': file, + 'prompt': prompt + }) + # Call the original VideoAnalysis tool + result = VideoAnalysis().call(single_url_params, **kwargs) + results.append(f"# Video: {os.path.basename(file)}\n{result}") + except Exception as e: + results.append(f"# Error processing {os.path.basename(file)}: {str(e)}") + + return results + + +@register_tool("VideoAgent") +class VideoAgent(BaseTool): + description = "Video/audio analysis with object detection, text extraction, scene understanding, and metadata analysis." 
+ parameters = [ + { + 'name': 'query', + 'type': 'string', + 'description': 'Detailed question/instruction for analysis.', + 'required': True + }, + { + 'name': 'files', + 'type': 'array', + 'array_type': 'string', + 'description': 'The files to be analyzed.', + 'required': True + } + ] + + async def call(self, params): + response = await video_analysis(params) + return json.dumps(response, ensure_ascii=False) + + +if __name__ == "__main__": + agent = VideoAgent() + params = { + 'query': "Could you help me out with this assignment? Our professor sprung it on us at the end of class Friday, and I'm still trying to figure it out. The question he asked us was about an anagram. I've attached an audio recording of the question that he asked, so if you could please take a listen and give me the answer, I'd really appreciate the help. Please limit your response to the anagram text that could be generated from the original line which fulfills the professor's request, without any other commentary. 
Also, please don't include any punctuation in your response.", + 'files': ["datas/2b3ef98c-cc05-450b-a719-711aee40ac65.mp3"] + } + response = asyncio.run(agent.call(params)) + print(response) diff --git a/fastapi_app/deep_research/file_tools/video_analysis.py b/fastapi_app/deep_research/file_tools/video_analysis.py new file mode 100644 index 0000000..acc0a58 --- /dev/null +++ b/fastapi_app/deep_research/file_tools/video_analysis.py @@ -0,0 +1,619 @@ +import base64 +import io +import json +import os +import tempfile +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, List, Literal, Optional, TypedDict, Union +from urllib.parse import urlparse + +import requests +from PIL import Image +from openai import OpenAI +from qwen_agent.log import logger +from qwen_agent.tools.base import BaseTool, register_tool + +# Configuration constants +MAX_FILE_SIZE = 500 * 1024 * 1024 # 500MB +SUPPORTED_VIDEO_TYPES = {'.mp4', '.mov', '.avi', '.mkv', '.webm'} +SUPPORTED_AUDIO_TYPES = {'.mp3', '.wav', '.aac', '.ogg', '.flac'} +DEFAULT_FRAMES = 8 +RETRY_ATTEMPTS = 3 +RETRY_DELAY = 1 + + +class AnalysisResult(TypedDict): + """Type definition for analysis results""" + status: Literal['success', 'error'] + data: Optional[Dict] + error: Optional[Dict] + + +@contextmanager +def temp_directory(): + """Context manager for temporary directory handling""" + temp_dir = tempfile.TemporaryDirectory() + try: + logger.debug(f"Created temp directory: {temp_dir.name}") + yield Path(temp_dir.name) + finally: + try: + temp_dir.cleanup() + logger.debug("Cleaned up temp directory") + except Exception as e: + logger.warning(f"Temp directory cleanup failed: {str(e)}") + + +@register_tool('video_analysis') +class VideoAnalysis(BaseTool): + """Improved tool for analyzing video and audio content""" + parameters = [ + { + 'name': 'prompt', + 'type': 'string', + 'description': 'Detailed question or instruction for video/audio analysis', + 'required': True + }, + { 
+ 'name': 'url', + 'type': 'string', + 'description': 'Media file URL/path (supports video/audio)', + 'required': True + }, + { + 'name': 'num_frames', + 'type': 'number', + 'description': 'Number of key frames to extract (default: 8)', + 'required': False + } + ] + + def __init__(self, cfg: Optional[Dict] = None): + super().__init__(cfg or {}) + self.config = self._init_config(cfg or {}) + self.client = OpenAI( + api_key=self.config['api_key'], + base_url=self.config['api_base'], + timeout=self.config['timeout'] + ) + self.http_session = self._init_http_client() + self._check_dependencies() + logger.info("Video analysis tool initialized") + + def _init_config(self, cfg: Dict) -> Dict: + """Initialize configuration with sensible defaults""" + return { + 'api_key': os.getenv('DASHSCOPE_API_KEY', ''), + 'api_base': cfg.get('api_base') or os.getenv('DASHSCOPE_API_BASE', ''), + 'video_model': cfg.get('video_model') or os.getenv('VIDEO_MODEL_NAME', 'qwen-omni-turbo'), + 'analysis_model': cfg.get('analysis_model') or os.getenv('VIDEO_ANALYSIS_MODEL_NAME', 'qwen-plus-latest'), + 'timeout': min(cfg.get('timeout', 30), 300), # Cap at 300 seconds + 'max_frames': min(cfg.get('max_frames', 20), 50), # Cap at 50 frames + 'max_file_size': MAX_FILE_SIZE + } + + def _init_http_client(self) -> requests.Session: + """Initialize HTTP client with retry logic""" + session = requests.Session() + retry_adapter = requests.adapters.HTTPAdapter( + max_retries=requests.packages.urllib3.util.Retry( + total=RETRY_ATTEMPTS, + backoff_factor=RETRY_DELAY, + status_forcelist=[500, 502, 503, 504] + ) + ) + session.mount('http://', retry_adapter) + session.mount('https://', retry_adapter) + return session + + def _check_dependencies(self): + """Check for required dependencies""" + # Check for FFmpeg wrapper library + try: + # Try to import as ffmpeg_python to avoid name collisions + import ffmpeg as ffmpeg_lib + # Verify it's the correct library by checking for the input method + if 
hasattr(ffmpeg_lib, 'input'): + self.ffmpeg = ffmpeg_lib + logger.debug("Successfully loaded ffmpeg-python library") + else: + logger.warning( + "Found 'ffmpeg' module but it's not the ffmpeg-python package. Will use subprocess fallback.") + self.ffmpeg = None + except ImportError: + logger.warning("ffmpeg-python not installed. Will use subprocess fallback for media operations.") + self.ffmpeg = None + + # Check for scene detection capability + try: + from scenedetect import SceneManager, VideoManager + from scenedetect.detectors import ContentDetector + self._scene_detect_available = True + except ImportError: + logger.warning("SceneDetect not available. Using basic frame extraction.") + self._scene_detect_available = False + + def call(self, params: Union[str, Dict], **kwargs) -> AnalysisResult: + """Execute media analysis""" + result: AnalysisResult = { + 'status': 'success', + 'data': None, + 'error': None + } + + try: + # Parse and validate parameters + params = self._parse_params(params) + logger.info(f"Starting analysis task: {params['url']}") + + with temp_directory() as temp_dir: + # Process input file + media_path = self._process_input(params['url'], temp_dir) + self._validate_media_file(media_path) + + # Determine media type + is_audio = self._is_audio_only(media_path) + + # Audio transcription + audio_path = media_path if is_audio else self._extract_audio(media_path, temp_dir) + transcript = self._transcribe_audio(audio_path) + + # Key frame extraction (for videos only) + frames = [] + if not is_audio: + frames = self._extract_keyframes( + media_path, + min(params['num_frames'], self.config['max_frames']) + ) + + # AI analysis + analysis_result = self._analyze_media( + prompt=params['prompt'], + transcript=transcript, + frames=frames, + is_audio=is_audio + ) + + result['data'] = { + 'transcript': transcript, + 'frame_count': len(frames), + 'analysis': analysis_result + } + + except Exception as e: + result.update({ + 'status': 'error', + 'error': { + 
'type': type(e).__name__, + 'message': str(e), + 'details': getattr(e, 'details', '') + } + }) + logger.error(f"Analysis failed: {str(e)}", exc_info=True) + + return result + + def _parse_params(self, params: Union[str, Dict]) -> Dict: + """Parse and validate parameters""" + if isinstance(params, str): + try: + params = json.loads(params) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON parameters: {str(e)}") + + required = ['url', 'prompt'] + missing = [f for f in required if f not in params] + if missing: + raise ValueError(f"Missing required parameters: {', '.join(missing)}") + + return { + 'url': params['url'], + 'prompt': params['prompt'], + 'num_frames': min( + int(params.get('num_frames', DEFAULT_FRAMES)), + self.config['max_frames'] + ) + } + + def _process_input(self, url: str, temp_dir: Path) -> Path: + """Process input URL/path and get local file path""" + parsed = urlparse(url) + if parsed.scheme in ('http', 'https'): + return self._download_media(url, temp_dir) + return self._resolve_local_path(url) + + def _get_video_duration(self, video_path: Path) -> float: + """Get video duration in seconds""" + # Try ffmpeg-python first + if self.ffmpeg: + try: + probe = self.ffmpeg.probe(str(video_path)) + return float(probe['format']['duration']) + except Exception as e: + logger.warning(f"ffmpeg-python probe failed: {str(e)}") + + # Fallback to subprocess + try: + import subprocess + import json + cmd = [ + 'ffprobe', '-v', 'error', + '-show_entries', 'format=duration', + '-of', 'json', str(video_path) + ] + result = subprocess.run(cmd, check=True, capture_output=True, text=True) + data = json.loads(result.stdout) + return float(data['format']['duration']) + except Exception as e: + logger.warning(f"Subprocess duration check failed: {str(e)}") + # Default to a reasonable duration if all else fails + return 60.0 # Assume 1 minuteimport base64 + + def _download_media(self, url: str, save_dir: Path) -> Path: + """Download remote media file 
with validation""" + logger.info(f"Starting download: {url}") + + try: + # Pre-validate request + head_res = self.http_session.head(url, timeout=10) + head_res.raise_for_status() + + # Validate content type + content_type = head_res.headers.get('Content-Type', '') + file_ext = self._get_file_extension(content_type, url) + if not self._is_supported_type(file_ext): + raise ValueError(f"Unsupported file type: {file_ext}") + + # Validate file size + content_length = int(head_res.headers.get('Content-Length', 0)) + if content_length > self.config['max_file_size']: + raise ValueError( + f"File size ({content_length / 1e6:.2f}MB) exceeds limit ({self.config['max_file_size'] / 1e6}MB)" + ) + + # Download file in chunks + save_path = save_dir / f"media_{int(time.time())}{file_ext}" + with self.http_session.get(url, stream=True, timeout=self.config['timeout']) as res: + res.raise_for_status() + self._stream_write_file(res, save_path) + + logger.info(f"Download completed: {save_path}") + return save_path + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Download failed: {str(e)}") from e + + def _stream_write_file(self, response: requests.Response, save_path: Path) -> None: + """Stream file content to disk with progress tracking""" + total_size = 0 + start_time = time.time() + + with open(save_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + total_size += len(chunk) + + # Log progress periodically + if time.time() - start_time > 1: + logger.debug(f"Downloaded: {total_size / 1e6:.2f}MB") + start_time = time.time() + + if total_size > self.config['max_file_size']: + raise ValueError("File size exceeds limit") + + def _resolve_local_path(self, path: str) -> Path: + """Resolve local file path, handling relative paths""" + media_path = Path(path) + if not media_path.is_absolute(): + base_path = Path(os.getenv('PROJECT_ROOT', os.getcwd())) + media_path = base_path / media_path + + if not 
media_path.exists(): + raise FileNotFoundError(f"File not found: {media_path}") + return media_path + + def _validate_media_file(self, path: Path) -> None: + """Validate media file existence and size""" + if not path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + file_size = path.stat().st_size + if file_size > self.config['max_file_size']: + raise ValueError( + f"File size ({file_size / 1e6:.2f}MB) exceeds limit ({self.config['max_file_size'] / 1e6}MB)" + ) + + if not self._is_supported_type(path.suffix): + raise ValueError(f"Unsupported file type: {path.suffix}") + + def _is_supported_type(self, extension: str) -> bool: + """Check if file type is supported""" + ext = extension.lower().lstrip('.') + return ext in {ext.lstrip('.') for ext in SUPPORTED_VIDEO_TYPES | SUPPORTED_AUDIO_TYPES} + + def _get_file_extension(self, content_type: str, url: str) -> str: + """Get file extension from content type or URL""" + # Try from Content-Type + if content_type: + type_map = { + 'video/mp4': '.mp4', + 'video/quicktime': '.mov', + 'audio/mpeg': '.mp3', + 'audio/wav': '.wav', + 'audio/aac': '.aac' + } + if ext := type_map.get(content_type.split(';')[0]): + return ext + + # Try from URL path + if path_ext := Path(urlparse(url).path).suffix: + return path_ext + + return '.mp4' # Default extension + + def _is_audio_only(self, path: Path) -> bool: + """Detect if file is audio-only""" + # Check by extension first + if path.suffix.lower() in SUPPORTED_AUDIO_TYPES: + return True + + # Then try to use ffmpeg probe + if self.ffmpeg: + try: + probe = self.ffmpeg.probe(str(path)) + return not any(s['codec_type'] == 'video' for s in probe['streams']) + except Exception as e: + logger.warning(f"Media probe failed: {str(e)}") + + # If ffmpeg-python not available, use subprocess + try: + import subprocess + cmd = ['ffprobe', '-v', 'error', '-show_entries', + 'stream=codec_type', '-of', 'json', str(path)] + result = subprocess.run(cmd, check=True, capture_output=True, 
text=True) + import json + probe_data = json.loads(result.stdout) + return not any(s.get('codec_type') == 'video' + for s in probe_data.get('streams', [])) + except Exception as e: + logger.warning(f"Subprocess probe failed: {str(e)}") + # If all else fails, use file extension + return path.suffix.lower() in SUPPORTED_AUDIO_TYPES + + def _extract_audio(self, video_path: Path, temp_dir: Path) -> Path: + """Extract audio from video""" + logger.info(f"Extracting audio: {video_path}") + output_path = temp_dir / f"audio_{video_path.stem}.mp3" + + # First try using ffmpeg-python if available + if self.ffmpeg: + try: + ( + self.ffmpeg.input(str(video_path)) + .output(str(output_path), vn=None, acodec='libmp3lame', loglevel='error') + .run(overwrite_output=True) + ) + return output_path + except Exception as e: + logger.warning(f"ffmpeg-python extraction failed: {str(e)}. Trying subprocess fallback.") + # Fall through to subprocess method + + # Fallback to direct subprocess call + try: + import subprocess + cmd = [ + 'ffmpeg', '-i', str(video_path), + '-vn', '-acodec', 'libmp3lame', + '-y', str(output_path) + ] + subprocess.run(cmd, check=True, capture_output=True) + return output_path + except subprocess.SubprocessError as e: + raise RuntimeError(f"Audio extraction failed: {str(e)}") from e + except Exception as e: + raise RuntimeError(f"Audio extraction failed: {str(e)}") from e + + def _transcribe_audio(self, audio_path: Path) -> str: + """Transcribe audio to text""" + logger.info(f"Starting transcription: {audio_path}") + start_time = time.time() + + try: + with open(audio_path, 'rb') as f: + base64_audio = base64.b64encode(f.read()).decode() + + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "Completely transcribe this audio content with all details"}, + { + "type": "input_audio", + "input_audio": { + "data": f"data:audio/mp3;base64,{base64_audio}", + "format": "mp3" + } + } + ] + }] + response = self.client.chat.completions.create( + 
model=self.config['video_model'], + messages=messages, + stream=True + ) + + transcript = [] + for chunk in response: + if chunk.choices and chunk.choices[0].delta.content: + transcript.append(chunk.choices[0].delta.content) + + final_text = ''.join(transcript).strip() + logger.info(f"Transcription completed (time: {time.time() - start_time:.1f}s, chars: {len(final_text)})") + return final_text + + except Exception as e: + logger.error(f"Transcription failed: {str(e)}") + return "" + + def _extract_keyframes(self, video_path: Path, num_frames: int) -> List[str]: + """Extract key frames intelligently""" + logger.info(f"Extracting key frames: {video_path}") + frames = [] + + try: + import ffmpeg + + # Use scene detection if available + if self._scene_detect_available: + frames = self._extract_frames_with_scene_detection(video_path, num_frames) + else: + frames = self._extract_frames_uniform(video_path, num_frames) + + return frames + + except ImportError as e: + logger.error(f"Missing dependency: {str(e)}") + return [] + except Exception as e: + logger.error(f"Frame extraction failed: {str(e)}") + return [] + + def _extract_frames_with_scene_detection(self, video_path: Path, num_frames: int) -> List[str]: + """Extract frames based on scene changes""" + try: + from scenedetect import detect, ContentDetector + + # Detect scene changes + scene_list = detect(str(video_path), ContentDetector(threshold=30)) + timestamps = [scene[0].get_seconds() for scene in scene_list] + + # Get video duration + duration = self._get_video_duration(video_path) + + # If no scenes detected or too few, use uniform sampling + if not timestamps or len(timestamps) < num_frames: + # Calculate how many additional frames we need + additional_needed = num_frames - len(timestamps) + if additional_needed > 0: + # Create evenly spaced timestamps for remaining frames + interval = duration / (additional_needed + 1) + extra_timestamps = [interval * (i + 1) for i in range(additional_needed)] + 
timestamps.extend(extra_timestamps) + timestamps.sort() + + # If too many scenes detected, select a representative subset + if len(timestamps) > num_frames: + step = len(timestamps) // num_frames + timestamps = [timestamps[i] for i in range(0, len(timestamps), step)][:num_frames] + + # Capture frames at timestamps + return [ + self._frame_to_base64(self._capture_frame(video_path, ts)) + for ts in timestamps[:num_frames] + if self._capture_frame(video_path, ts) is not None + ] + + except Exception as e: + logger.warning(f"Scene detection failed, falling back to uniform sampling: {str(e)}") + return self._extract_frames_uniform(video_path, num_frames) + + def _extract_frames_uniform(self, video_path: Path, num_frames: int) -> List[str]: + """Extract frames at uniform intervals""" + try: + # Get video duration + duration = self._get_video_duration(video_path) + + # Calculate evenly spaced timestamps + interval = duration / (num_frames + 1) + timestamps = [interval * (i + 1) for i in range(num_frames)] + + # Capture frames + return [ + self._frame_to_base64(self._capture_frame(video_path, ts)) + for ts in timestamps + if self._capture_frame(video_path, ts) is not None + ] + + except Exception as e: + logger.error(f"Uniform frame extraction failed: {str(e)}") + return [] + + def _capture_frame(self, video_path: Path, timestamp: float) -> Optional[Image.Image]: + """Capture a video frame at specified timestamp""" + output_file = video_path.parent / f"frame_{timestamp}.jpg" + + # Try ffmpeg-python if available + if self.ffmpeg: + try: + ( + self.ffmpeg.input(str(video_path), ss=timestamp) + .output(str(output_file), vframes=1, q=2, loglevel='error') + .run(overwrite_output=True) + ) + return Image.open(output_file) + except Exception as e: + logger.warning(f"ffmpeg-python frame capture failed: {str(e)}") + # Fall through to subprocess method + + # Fallback to subprocess + try: + import subprocess + cmd = [ + 'ffmpeg', '-ss', str(timestamp), + '-i', str(video_path), + 
'-vframes', '1', '-q:v', '2', + '-y', str(output_file) + ] + subprocess.run(cmd, check=True, capture_output=True) + return Image.open(output_file) + except Exception as e: + logger.warning(f"Frame capture failed at {timestamp}s: {str(e)}") + return None + + def _frame_to_base64(self, image: Image.Image) -> str: + """Convert image to base64 string""" + buffered = io.BytesIO() + image.save(buffered, format="JPEG", quality=85, optimize=True) + return base64.b64encode(buffered.getvalue()).decode() + + def _analyze_media(self, prompt: str, transcript: str, frames: List[str], is_audio: bool) -> str: + """Analyze media using AI model""" + logger.info(f"Starting AI analysis ({'audio' if is_audio else 'video'})") + messages = self._build_analysis_messages(prompt, transcript, frames, is_audio) + try: + response = self.client.chat.completions.create( + model=self.config['analysis_model'], + messages=messages, + temperature=0.3, + ) + return response.choices[0].message.content + except Exception as e: + logger.error(f"AI analysis failed: {str(e)}") + return "Analysis generation failed" + + def _build_analysis_messages(self, prompt: str, transcript: str, frames: List[str], is_audio: bool) -> List[Dict]: + """Build prompt messages for analysis""" + system_prompt = ( + f"You are a professional {'audio' if is_audio else 'video'} analysis expert. " + "Your task is to analyze the provided content by:\n" + "1. Identifying key information and contextual relationships\n" + "2. Noting time-sequence information\n" + "3. 
Providing expert answers to the user's question" + ) + + content = [ + {"type": "text", "text": f"User question: {prompt}\n\nAudio transcription:\n{transcript}"} + ] + + if not is_audio: + content.extend([ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}} + for img in frames + ]) + + return [ + {"role": "system", "content": [{"type": "text", "text": system_prompt}]}, + {"role": "user", "content": content} + ] \ No newline at end of file diff --git a/fastapi_app/deep_research/prompt.py b/fastapi_app/deep_research/prompt.py new file mode 100644 index 0000000..115bdb9 --- /dev/null +++ b/fastapi_app/deep_research/prompt.py @@ -0,0 +1,51 @@ +SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. When you have gathered sufficient information and are ready to provide the definitive response, you must enclose the entire final answer within tags. + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type": "function", "function": {"name": "search", "description": "Perform Google web searches then returns a string of the top search results. 
Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The specific information goal for visiting webpage(s)."}}, "required": ["url", "goal"]}}} +{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes Python code in a sandboxed environment. To use this tool, you must follow this format: +1. The 'arguments' JSON object must be empty: {}. +2. The Python code to be executed must be placed immediately after the JSON block, enclosed within and tags. + +IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function. + +Example of a correct call: + +{"name": "PythonInterpreter", "arguments": {}} + +import numpy as np +# Your code here +print(f"The result is: {np.mean([1,2,3])}") + +", "parameters": {"type": "object", "properties": {}, "required": []}}} +{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. 
This tool will also return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "parse_file", "description": "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3.", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "The file name of the user uploaded local files to be parsed."}}, "required": ["files"]}}} + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } + + +Current date: """ + +EXTRACTOR_PROMPT = """Please process the following webpage content and user goal to extract relevant information: + +## **Webpage Content** +{webpage_content} + +## **User Goal** +{goal} + +## **Task Guidelines** +1. **Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content +2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs. +3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal. 
+ +**Final Output Format using JSON format has "rational", "evidence", "summary" feilds** +""" diff --git a/fastapi_app/deep_research/react_agent.py b/fastapi_app/deep_research/react_agent.py new file mode 100644 index 0000000..66d7042 --- /dev/null +++ b/fastapi_app/deep_research/react_agent.py @@ -0,0 +1,279 @@ +import json +import json5 +import os +from typing import Dict, Iterator, List, Literal, Optional, Tuple, Union +from qwen_agent.llm.schema import Message +from qwen_agent.utils.utils import build_text_completion_prompt +from openai import OpenAI, APIError, APIConnectionError, APITimeoutError +from transformers import AutoTokenizer +from datetime import datetime +from qwen_agent.agents.fncall_agent import FnCallAgent +from qwen_agent.llm import BaseChatModel +from qwen_agent.llm.schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, Message +from qwen_agent.settings import MAX_LLM_CALL_PER_RUN +from qwen_agent.tools import BaseTool +from qwen_agent.utils.utils import format_as_text_message, merge_generate_cfgs +from .prompt import * +import time +import asyncio + +from .tool_file import * +from .tool_scholar import * +from .tool_python import * +from .tool_search import * +from .tool_visit import * + +OBS_START = '' +OBS_END = '\n' + +MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100)) + +TOOL_CLASS = [ + FileParser(), + Scholar(), + Visit(), + Search(), + PythonInterpreter(), +] +TOOL_MAP = {tool.name: tool for tool in TOOL_CLASS} + +import random +import datetime + + +def today_date(): + return datetime.date.today().strftime("%Y-%m-%d") + +class MultiTurnReactAgent(FnCallAgent): + def __init__(self, + function_list: Optional[List[Union[str, Dict, BaseTool]]] = None, + llm: Optional[Union[Dict, BaseChatModel]] = None, + **kwargs): + + self.llm_generate_cfg = llm["generate_cfg"] + self.llm_local_path = llm["model"] + self.api_key = llm.get("api_key", "EMPTY") # 保存 API key + self.api_base = llm.get("api_base", "") # 保存 API base URL + + def 
sanity_check_output(self, content): + return "" in content and "" in content + + def call_server(self, msgs, planning_port, max_tries=3): + + # 使用传入的 planning_port 构造 API base URL + # planning_port 实际上是完整的 API base URL(从 api_base 传递过来) + if isinstance(planning_port, str) and ('http://' in planning_port or 'https://' in planning_port): + # 如果 planning_port 是完整的 URL,直接使用 + openai_api_base = planning_port if planning_port.endswith('/v1') else f"{planning_port}/v1" + else: + # 否则假设是端口号 + openai_api_base = f"http://127.0.0.1:{planning_port}/v1" + + # 使用实际的 API key,而不是硬编码的 "EMPTY" + openai_api_key = self.api_key + + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + timeout=600.0, + ) + + base_sleep_time = 1 + for attempt in range(max_tries): + try: + print(f"--- Attempting to call the service, try {attempt + 1}/{max_tries} ---") + chat_response = client.chat.completions.create( + model=self.model, + messages=msgs, + stop=["\n", ""], + temperature=self.llm_generate_cfg.get('temperature', 0.6), + top_p=self.llm_generate_cfg.get('top_p', 0.95), + logprobs=True, + max_tokens=10000, + presence_penalty=self.llm_generate_cfg.get('presence_penalty', 1.1) + ) + content = chat_response.choices[0].message.content + + # OpenRouter provides API calling. If you want to use OpenRouter, you need to uncomment line 89 - 90. 
+ # reasoning_content = "\n" + chat_response.choices[0].message.reasoning.strip() + "\n" + # content = reasoning_content + content + + if content and content.strip(): + print("--- Service call successful, received a valid response ---") + return content.strip() + else: + print(f"Warning: Attempt {attempt + 1} received an empty response.") + + except (APIError, APIConnectionError, APITimeoutError) as e: + print(f"Error: Attempt {attempt + 1} failed with an API or network error: {e}") + except Exception as e: + print(f"Error: Attempt {attempt + 1} failed with an unexpected error: {e}") + + if attempt < max_tries - 1: + sleep_time = base_sleep_time * (2 ** attempt) + random.uniform(0, 1) + sleep_time = min(sleep_time, 10) # 最大等待 10 秒 + + print(f"Retrying in {sleep_time:.2f} seconds...") + time.sleep(sleep_time) + else: + print("Error: All retry attempts have been exhausted. The call has failed.") + + return f"vllm server error!!!" + + def count_tokens(self, messages): + """ + 估算消息的 token 数量 + 使用 tiktoken 或简单估算,避免依赖特定模型的 tokenizer + """ + try: + # 方案 1: 尝试使用 tiktoken(OpenAI 的 tokenizer) + import tiktoken + encoding = tiktoken.get_encoding("cl100k_base") + + # 将消息转换为文本 + full_text = "" + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + full_text += f"{role}: {content}\n" + + token_count = len(encoding.encode(full_text)) + except: + # 方案 2: 如果 tiktoken 不可用,使用简单估算 + # 平均每个 token 约 4 个字符(英文)或 1.5 个字符(中文) + full_text = "" + for msg in messages: + content = msg.get("content", "") + full_text += content + # 使用保守估算:每 3 个字符 = 1 token + token_count = len(full_text) // 3 + return token_count + + def _run(self, data: str, model: str, **kwargs) -> List[List[Message]]: + self.model=model + try: + question = data['item']['question'] + except: + raw_msg = data['item']['messages'][1]["content"] + question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg + + start_time = time.time() + planning_port = data['planning_port'] + 
answer = data['item']['answer'] + self.user_prompt = question + system_prompt = SYSTEM_PROMPT + cur_date = today_date() + system_prompt = system_prompt + str(cur_date) + messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}] + num_llm_calls_available = MAX_LLM_CALL_PER_RUN + round = 0 + while num_llm_calls_available > 0: + # Check whether time is reached + if time.time() - start_time > 150 * 60: # 150 minutes in seconds + prediction = 'No answer found after 2h30mins' + termination = 'No answer found after 2h30mins' + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination + } + return result + round += 1 + num_llm_calls_available -= 1 + content = self.call_server(messages, planning_port) + print(f'Round {round}: {content}') + if '' in content: + pos = content.find('') + content = content[:pos] + messages.append({"role": "assistant", "content": content.strip()}) + if '' in content and '' in content: + tool_call = content.split('')[1].split('')[0] + try: + if "python" in tool_call.lower(): + try: + code_raw=content.split('')[1].split('')[0].split('')[1].split('')[0].strip() + result = TOOL_MAP['PythonInterpreter'].call(code_raw) + except: + result = "[Python Interpreter Error]: Formatting error." + + else: + tool_call = json5.loads(tool_call) + tool_name = tool_call.get('name', '') + tool_args = tool_call.get('arguments', {}) + result = self.custom_call_tool(tool_name, tool_args) + + except: + result = 'Error: Tool call is not a valid JSON. Tool call must contain a valid "name" and "arguments" field.' + result = "\n" + result + "\n" + # print(result) + messages.append({"role": "user", "content": result}) + if '' in content and '' in content: + termination = 'answer' + break + if num_llm_calls_available <= 0 and '' not in content: + messages[-1]['content'] = 'Sorry, the number of llm calls exceeds the limit.' 
+ + max_tokens = 110 * 1024 + token_count = self.count_tokens(messages) + print(f"round: {round}, token count: {token_count}") + + if token_count > max_tokens: + print(f"Token quantity exceeds the limit: {token_count} > {max_tokens}") + + messages[-1]['content'] = "You have now reached the maximum context length you can handle. You should stop making tool calls and, based on all the information above, think again and provide what you consider the most likely answer in the following format:your final thinking\nyour answer" + content = self.call_server(messages, planning_port) + messages.append({"role": "assistant", "content": content.strip()}) + if '' in content and '' in content: + prediction = messages[-1]['content'].split('')[1].split('')[0] + termination = 'generate an answer as token limit reached' + else: + prediction = messages[-1]['content'] + termination = 'format error: generate an answer as token limit reached' + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination + } + return result + + if '' in messages[-1]['content']: + prediction = messages[-1]['content'].split('')[1].split('')[0] + termination = 'answer' + else: + prediction = 'No answer found.' 
+ termination = 'answer not found' + if num_llm_calls_available == 0: + termination = 'exceed available llm calls' + result = { + "question": question, + "answer": answer, + "messages": messages, + "prediction": prediction, + "termination": termination + } + return result + + def custom_call_tool(self, tool_name: str, tool_args: dict, **kwargs): + if tool_name in TOOL_MAP: + tool_args["params"] = tool_args + if "python" in tool_name.lower(): + result = TOOL_MAP['PythonInterpreter'].call(tool_args) + elif tool_name == "parse_file": + params = {"files": tool_args["files"]} + + raw_result = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) + result = raw_result + + if not isinstance(raw_result, str): + result = str(raw_result) + else: + raw_result = TOOL_MAP[tool_name].call(tool_args, **kwargs) + result = raw_result + return result + + else: + return f"Error: Tool {tool_name} not found" diff --git a/fastapi_app/deep_research/tool_file.py b/fastapi_app/deep_research/tool_file.py new file mode 100644 index 0000000..c9ab69f --- /dev/null +++ b/fastapi_app/deep_research/tool_file.py @@ -0,0 +1,141 @@ +""" +input: + - query/goal: str + - Docs: List[file]/List[url] + - file type: 'pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'doc', 'zip', '.mp4', '.mov', '.avi', '.mkv', '.webm', '.mp3', '.wav', '.aac', '.ogg', '.flac' +output: + - answer: str + - useful_information: str +""" +import sys +import os +import re +import time +import copy +import json +from typing import Dict, Iterator, List, Literal, Tuple, Union, Any, Optional +import json5 +import asyncio +from openai import OpenAI, AsyncOpenAI +import pdb +import bdb + +from qwen_agent.tools.base import BaseTool, register_tool +from qwen_agent.agents import Assistant +from qwen_agent.llm import BaseChatModel +from qwen_agent.settings import DEFAULT_WORKSPACE, DEFAULT_MAX_INPUT_TOKENS +from qwen_agent.llm.schema import ASSISTANT, USER, FUNCTION, Message, 
DEFAULT_SYSTEM_MESSAGE, SYSTEM, ROLE +from qwen_agent.tools import BaseTool +from qwen_agent.log import logger +from qwen_agent.utils.tokenization_qwen import count_tokens, tokenizer +from qwen_agent.settings import DEFAULT_WORKSPACE, DEFAULT_MAX_INPUT_TOKENS + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(current_dir)) +sys.path.append('../../') + +from .file_tools.file_parser import SingleFileParser, compress +from .file_tools.video_agent import VideoAgent + +FILE_SUMMARY_PROMPT = """ +Please process the following file content and user goal to extract relevant information: + +## **File Content** +{file_content} + +## **User Goal** +{goal} + +## **Task Guidelines** +1. **Content Scanning for Rational**: Locate the **specific sections/data** directly related to the user's goal within the file content +2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs. +3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal. 
+""".strip() + + +async def file_parser(params, **kwargs): + """Parse files with automatic path resolution""" + urls = params.get('files', []) + if isinstance(urls, str): + urls = [urls] + + resolved_urls = [] + for url in urls: + if isinstance(url, list): + for sub_url in url: + if sub_url.startswith(("http://", "https://")): + resolved_urls.append(sub_url) + else: + abs_path = os.path.abspath(sub_url) + if os.path.exists(abs_path): + resolved_urls.append(abs_path) + else: + resolved_urls.append(sub_url) + else: + if url.startswith(("http://", "https://")): + resolved_urls.append(url) + else: + abs_path = os.path.abspath(url) + if os.path.exists(abs_path): + resolved_urls.append(abs_path) + else: + resolved_urls.append(url) + + results = [] + file_results = [] + for url in resolved_urls: + try: + result = SingleFileParser().call(json.dumps({'url': url}), **kwargs) + results.append(f"# File: {os.path.basename(url)}\n{result}") + file_results.append(result) + except Exception as e: + results.append(f"# Error processing {os.path.basename(url)}: {str(e)}") + if count_tokens(json.dumps(results)) < DEFAULT_MAX_INPUT_TOKENS: + return results + else: + return compress(file_results) + +# @register_tool("file_parser") +class FileParser(BaseTool): + name = "parse_file" + description = "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3." 
+ parameters = [ + { + 'name': 'files', + 'type': 'array', + 'array_type': 'string', + 'description': 'The file name of the user uploaded local files to be parsed.', + 'required': True + } + ] + + async def call(self, params, file_root_path): + file_name = params["files"] + outputs = [] + + file_path = [] + omnifile_path = [] + for f_name in file_name: + if '.mp3' not in f_name: + file_path.append(os.path.join(file_root_path, f_name)) + else: + omnifile_path.append(os.path.join(file_root_path, f_name)) + + if len(file_path): + params = {'files': file_path} + response = await file_parser(params) + response = response[:30000] + + parsed_file_content = ' '.join(response) + outputs.extend([f'File token number: {len(parsed_file_content.split())}\nFile content:\n']+response) + + + if len(omnifile_path): + params['files'] = omnifile_path + agent = VideoAgent() + res = await agent.call(params) + + res = json.loads(res) + outputs += res + + return outputs diff --git a/fastapi_app/deep_research/tool_python.py b/fastapi_app/deep_research/tool_python.py new file mode 100644 index 0000000..fb4c7c0 --- /dev/null +++ b/fastapi_app/deep_research/tool_python.py @@ -0,0 +1,150 @@ +import re +from typing import Dict, List, Optional, Union, Any +import json5 +from qwen_agent.tools.base import BaseToolWithFileAccess, register_tool +from qwen_agent.utils.utils import extract_code +from sandbox_fusion import run_code, RunCodeRequest, RunStatus +from requests.exceptions import Timeout +import os +import random +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]') + + +def has_chinese_chars(data: Any) -> bool: + text = f'{data}' + return bool(CHINESE_CHAR_RE.search(text)) + + +# Array of sandbox fusion endpoints +SANDBOX_FUSION_ENDPOINTS = [] + +# Fallback to single endpoint if environment variable exists +if 'SANDBOX_FUSION_ENDPOINT' in os.environ: + SANDBOX_FUSION_ENDPOINTS = 
os.environ['SANDBOX_FUSION_ENDPOINT'].split(',') + + +@register_tool('PythonInterpreter', allow_overwrite=True) +class PythonInterpreter(BaseToolWithFileAccess): + name = "PythonInterpreter" + description = 'Execute Python code in a sandboxed environment. Use this to run Python code and get the execution results.\n**Make sure to use print() for any output you want to see in the results.**\nFor code parameters, use placeholders first, and then put the code within XML tags, such as:\n\n{"purpose": , "name": , "arguments": {"code": ""}}\n\nHere is the code.\n\n\n' + + parameters = { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to execute. Must be provided within XML tags. Remember to use print() statements for any output you want to see.", + } + }, + "required": ["code"], + } + + def __init__(self, cfg: Optional[Dict] = None): + super().__init__(cfg) + # self.summary_mapping = SummaryMapping() + + @property + def args_format(self) -> str: + fmt = self.cfg.get('args_format') + if fmt is None: + if has_chinese_chars([self.name_for_human, self.name, self.description, self.parameters]): + fmt = 'The input for this tool should be a Markdown code block.' + + else: + fmt = 'Enclose the code within triple backticks (`) at the beginning and end of the code.' + return fmt + + def observation(self, tool: dict, tool_dict: dict, tool_results, empty_mode: bool=False, readpage: bool=False, max_observation_length: int=None, tokenizer=None): + print('test') + assert isinstance(tool_results, str), f"result of python code should be str, instead of {type(tool_results)}. 
{tool_results}" + return tool_results + + @property + def function(self) -> dict: + return { + 'name': self.name, + 'description': self.description, + 'parameters': self.parameters, + } + + def call(self, params, files= None, timeout = 50, **kwargs) -> str: + try: + code=params + last_error = None + for attempt in range(8): + try: + # Randomly sample an endpoint for each attempt + endpoint = random.choice(SANDBOX_FUSION_ENDPOINTS) + print(f"Attempt {attempt + 1}/5 using endpoint: {endpoint}") + + code_result = run_code(RunCodeRequest(code=code, language='python', run_timeout=timeout), max_attempts=1, client_timeout=timeout, endpoint=endpoint) + print("[Python] Code Result", code_result) + result = [] + if code_result.run_result.stdout: + result.append(f"stdout:\n{code_result.run_result.stdout}") + if code_result.run_result.stderr: + result.append(f"stderr:\n{code_result.run_result.stderr}") + if code_result.run_result.execution_time >= timeout-1: + result.append(f"[PythonInterpreter Error] TimeoutError: Execution timed out.") + result = '\n'.join(result) + print('SUCCESS RUNNING TOOL') + return result if result.strip() else 'Finished execution.' + + except Timeout as e: + last_error = f'[Python Interpreter Error] TimeoutError: Execution timed out on endpoint {endpoint}.' + print(f"Timeout on attempt {attempt + 1}: {last_error}") + if attempt == 4: # Last attempt + return last_error + continue + + except Exception as e: + last_error = f'[Python Interpreter Error]: {str(e)} on endpoint {endpoint}' + print(f"Error on attempt {attempt + 1}: {last_error}") + if attempt == 4: # Last attempt + return last_error + continue + + return last_error if last_error else '[Python Interpreter Error]: All attempts failed.' 
+ + except Exception as e: + return f"[Python Interpreter Error]: {str(e)}" + + def call_specific_endpoint(self, params: Union[str, dict], endpoint: str, timeout: Optional[int] = 30, **kwargs) -> tuple: + """Test a specific endpoint directly""" + try: + if type(params) is str: + params = json5.loads(params) + code = params.get('code', '') + if not code: + code = params.get('raw', '') + triple_match = re.search(r'```[^\n]*\n(.+?)```', code, re.DOTALL) + if triple_match: + code = triple_match.group(1) + except Exception: + code = extract_code(params) + + if not code.strip(): + return False, '[Python Interpreter Error]: Empty code.' + + try: + start_time = time.time() + code_result = run_code(RunCodeRequest(code=code, language='python', run_timeout=timeout), + max_attempts=1, client_timeout=timeout, endpoint=endpoint) + end_time = time.time() + + result = [] + if code_result.run_result.stdout: + result.append(f"stdout:\n{code_result.run_result.stdout}") + if code_result.run_result.stderr: + result.append(f"stderr:\n{code_result.run_result.stderr}") + + result = '\n'.join(result) + execution_time = end_time - start_time + return True, result if result.strip() else 'Finished execution.', execution_time + + except Timeout as e: + return False, f'[Python Interpreter Error] TimeoutError: Execution timed out.', None + except Exception as e: + return False, f'[Python Interpreter Error]: {str(e)}', None diff --git a/fastapi_app/deep_research/tool_scholar.py b/fastapi_app/deep_research/tool_scholar.py new file mode 100644 index 0000000..ae021b3 --- /dev/null +++ b/fastapi_app/deep_research/tool_scholar.py @@ -0,0 +1,110 @@ +import os +import json +import requests +from typing import Union, List +from qwen_agent.tools.base import BaseTool, register_tool +from concurrent.futures import ThreadPoolExecutor +import http.client + + +SERPER_KEY=os.environ.get('SERPER_KEY_ID') + + +@register_tool("google_scholar", allow_overwrite=True) +class Scholar(BaseTool): + name = 
"google_scholar" + description = "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries." + parameters = { + "type": "object", + "properties": { + "query": { + "type": "array", + "items": {"type": "string", "description": "The search query."}, + "minItems": 1, + "description": "The list of search queries for Google Scholar." + }, + }, + "required": ["query"], + } + + def google_scholar_with_serp(self, query: str): + conn = http.client.HTTPSConnection("google.serper.dev") + payload = json.dumps({ + "q": query, + }) + headers = { + 'X-API-KEY': SERPER_KEY, + 'Content-Type': 'application/json' + } + for i in range(5): + try: + conn.request("POST", "/scholar", payload, headers) + res = conn.getresponse() + break + except Exception as e: + print(e) + if i == 4: + return f"Google Scholar Timeout, return None, Please try again later." + continue + + + data = res.read() + + results = json.loads(data.decode("utf-8")) + try: + if "organic" not in results: + raise Exception(f"No results found for query: '{query}'. Use a less specific query.") + + web_snippets = list() + idx = 0 + if "organic" in results: + for page in results["organic"]: + idx += 1 + date_published = "" + if "year" in page: + date_published = "\nDate published: " + str(page["year"]) + + publicationInfo = "" + if "publicationInfo" in page: + publicationInfo = "\npublicationInfo: " + page["publicationInfo"] + + snippet = "" + if "snippet" in page: + snippet = "\n" + page["snippet"] + + link_info = "no available link" + if "pdfUrl" in page: + link_info = "pdfUrl: " + page["pdfUrl"] + + citedBy = "" + if "citedBy" in page: + citedBy = "\ncitedBy: " + str(page["citedBy"]) + + redacted_version = f"{idx}. 
[{page['title']}]({link_info}){publicationInfo}{date_published}{citedBy}\n{snippet}" + + redacted_version = redacted_version.replace("Your browser can't play this video.", "") + web_snippets.append(redacted_version) + + content = f"A Google scholar for '{query}' found {len(web_snippets)} results:\n\n## Scholar Results\n" + "\n\n".join(web_snippets) + return content + except: + return f"No results found for '{query}'. Try with a more general query." + + + def call(self, params: Union[str, dict], **kwargs) -> str: + # assert GOOGLE_SEARCH_KEY is not None, "Please set the IDEALAB_SEARCH_KEY environment variable." + try: + params = self._verify_json_format_args(params) + query = params["query"] + except: + return "[google_scholar] Invalid request format: Input must be a JSON object containing 'query' field" + + if isinstance(query, str): + response = self.google_scholar_with_serp(query) + else: + assert isinstance(query, List) + with ThreadPoolExecutor(max_workers=3) as executor: + + response = list(executor.map(self.google_scholar_with_serp, query)) + response = "\n=======\n".join(response) + return response diff --git a/fastapi_app/deep_research/tool_search.py b/fastapi_app/deep_research/tool_search.py new file mode 100644 index 0000000..e49c576 --- /dev/null +++ b/fastapi_app/deep_research/tool_search.py @@ -0,0 +1,133 @@ +import json +from concurrent.futures import ThreadPoolExecutor +from typing import List, Union +import requests +from qwen_agent.tools.base import BaseTool, register_tool +import asyncio +from typing import Dict, List, Optional, Union +import uuid +import http.client +import json + +import os + + +def get_serper_key(): + """Dynamically get SERPER_KEY from environment""" + return os.environ.get('SERPER_KEY_ID') + + +@register_tool("search", allow_overwrite=True) +class Search(BaseTool): + name = "search" + description = "Performs batched web searches: supply an array 'query'; the tool retrieves the top 10 results for each query in one call." 
+ parameters = { + "type": "object", + "properties": { + "query": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Array of query strings. Include multiple complementary search queries in a single call." + }, + }, + "required": ["query"], + } + + def __init__(self, cfg: Optional[dict] = None): + super().__init__(cfg) + def google_search_with_serp(self, query: str): + def contains_chinese_basic(text: str) -> bool: + return any('\u4E00' <= char <= '\u9FFF' for char in text) + conn = http.client.HTTPSConnection("google.serper.dev") + if contains_chinese_basic(query): + payload = json.dumps({ + "q": query, + "location": "China", + "gl": "cn", + "hl": "zh-cn" + }) + + else: + payload = json.dumps({ + "q": query, + "location": "United States", + "gl": "us", + "hl": "en" + }) + headers = { + 'X-API-KEY': get_serper_key(), + 'Content-Type': 'application/json' + } + + + for i in range(5): + try: + conn.request("POST", "/search", payload, headers) + res = conn.getresponse() + break + except Exception as e: + print(e) + if i == 4: + return f"Google search Timeout, return None, Please try again later." + continue + + data = res.read() + results = json.loads(data.decode("utf-8")) + + try: + if "organic" not in results: + raise Exception(f"No results found for query: '{query}'. Use a less specific query.") + + web_snippets = list() + idx = 0 + if "organic" in results: + for page in results["organic"]: + idx += 1 + date_published = "" + if "date" in page: + date_published = "\nDate published: " + page["date"] + + source = "" + if "source" in page: + source = "\nSource: " + page["source"] + + snippet = "" + if "snippet" in page: + snippet = "\n" + page["snippet"] + + redacted_version = f"{idx}. 
[{page['title']}]({page['link']}){date_published}{source}\n{snippet}" + redacted_version = redacted_version.replace("Your browser can't play this video.", "") + web_snippets.append(redacted_version) + + content = f"A Google search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" + "\n\n".join(web_snippets) + return content + except: + return f"No results found for '{query}'. Try with a more general query." + + + + def search_with_serp(self, query: str): + result = self.google_search_with_serp(query) + return result + + def call(self, params: Union[str, dict], **kwargs) -> str: + try: + query = params["query"] + except: + return "[Search] Invalid request format: Input must be a JSON object containing 'query' field" + + if isinstance(query, str): + # 单个查询 + response = self.search_with_serp(query) + else: + # 多个查询 + assert isinstance(query, List) + responses = [] + for q in query: + responses.append(self.search_with_serp(q)) + response = "\n=======\n".join(responses) + + return response + diff --git a/fastapi_app/deep_research/tool_visit.py b/fastapi_app/deep_research/tool_visit.py new file mode 100644 index 0000000..eb08b29 --- /dev/null +++ b/fastapi_app/deep_research/tool_visit.py @@ -0,0 +1,269 @@ +import json +import os +import signal +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Union +import requests +from qwen_agent.tools.base import BaseTool, register_tool +from .prompt import EXTRACTOR_PROMPT +from openai import OpenAI +import random +from urllib.parse import urlparse, unquote +import time +from transformers import AutoTokenizer +import tiktoken + +VISIT_SERVER_TIMEOUT = int(os.getenv("VISIT_SERVER_TIMEOUT", 200)) +WEBCONTENT_MAXLENGTH = int(os.getenv("WEBCONTENT_MAXLENGTH", 150000)) + + +def get_jina_api_keys(): + """Dynamically get JINA_API_KEYS from environment""" + return os.getenv("JINA_API_KEYS", "") + + +@staticmethod +def truncate_to_tokens(text: str, max_tokens: int 
= 95000) -> str: + encoding = tiktoken.get_encoding("cl100k_base") + + tokens = encoding.encode(text) + if len(tokens) <= max_tokens: + return text + + truncated_tokens = tokens[:max_tokens] + return encoding.decode(truncated_tokens) + +OSS_JSON_FORMAT = """# Response Formats +## visit_content +{"properties":{"rational":{"type":"string","description":"Locate the **specific sections/data** directly related to the user's goal within the webpage content"},"evidence":{"type":"string","description":"Identify and extract the **most relevant information** from the content, never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.","summary":{"type":"string","description":"Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal."}}}}""" + + +@register_tool('visit', allow_overwrite=True) +class Visit(BaseTool): + # The `description` tells the agent the functionality of this tool. + name = 'visit' + description = 'Visit webpage(s) and return the summary of the content.' + # The `parameters` tell the agent what input parameters the tool has. + parameters = { + "type": "object", + "properties": { + "url": { + "type": ["string", "array"], + "items": { + "type": "string" + }, + "minItems": 1, + "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs." + }, + "goal": { + "type": "string", + "description": "The goal of the visit for webpage(s)." + } + }, + "required": ["url", "goal"] + } + # The `call` method is the main function of the tool. 
+ def call(self, params: Union[str, dict], **kwargs) -> str: + try: + url = params["url"] + goal = params["goal"] + except: + return "[Visit] Invalid request format: Input must be a JSON object containing 'url' and 'goal' fields" + + start_time = time.time() + + # Create log folder if it doesn't exist + log_folder = "log" + os.makedirs(log_folder, exist_ok=True) + + if isinstance(url, str): + response = self.readpage_jina(url, goal) + else: + response = [] + assert isinstance(url, List) + start_time = time.time() + for u in url: + if time.time() - start_time > 900: + cur_response = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) + cur_response += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" + cur_response += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n" + else: + try: + cur_response = self.readpage_jina(u, goal) + except Exception as e: + cur_response = f"Error fetching {u}: {str(e)}" + response.append(cur_response) + response = "\n=======\n".join(response) + + print(f'Summary Length {len(response)}; Summary Content {response}') + return response.strip() + + def call_server(self, msgs, max_retries=2): + api_key = os.environ.get("API_KEY") + url_llm = os.environ.get("API_BASE") + model_name = os.environ.get("SUMMARY_MODEL_NAME", "") + client = OpenAI( + api_key=api_key, + base_url=url_llm, + ) + for attempt in range(max_retries): + try: + chat_response = client.chat.completions.create( + model=model_name, + messages=msgs, + temperature=0.7 + ) + content = chat_response.choices[0].message.content + if content: + try: + json.loads(content) + except: + # extract json from string + left = content.find('{') + right = content.rfind('}') + if left != -1 and right != -1 and left <= right: + content = content[left:right+1] + return content + except Exception as e: + # print(e) + 
if attempt == (max_retries - 1): + return "" + continue + + + def jina_readpage(self, url: str) -> str: + """ + Read webpage content using Jina service. + + Args: + url: The URL to read + goal: The goal/purpose of reading the page + + Returns: + str: The webpage content or error message + """ + max_retries = 3 + timeout = 50 + + jina_keys = get_jina_api_keys() + if not jina_keys: + print("[visit] JINA_API_KEYS not configured, skipping Jina service") + return "[visit] Failed to read page." + + for attempt in range(max_retries): + headers = { + "Authorization": f"Bearer {jina_keys}", + } + try: + response = requests.get( + f"https://r.jina.ai/{url}", + headers=headers, + timeout=timeout + ) + if response.status_code == 200: + webpage_content = response.text + return webpage_content + elif response.status_code == 401: + print(f"[visit] Jina API authentication failed (401). Please check your JINA_API_KEYS.") + return "[visit] Failed to read page." + else: + print(f"[visit] Jina API returned status {response.status_code}: {response.text}") + if attempt == max_retries - 1: + return "[visit] Failed to read page." + except Exception as e: + print(f"[visit] Jina API error on attempt {attempt + 1}/{max_retries}: {str(e)}") + time.sleep(0.5) + if attempt == max_retries - 1: + return "[visit] Failed to read page." + + return "[visit] Failed to read page." + + def html_readpage_jina(self, url: str) -> str: + max_attempts = 8 + for attempt in range(max_attempts): + content = self.jina_readpage(url) + service = "jina" + print(service) + if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"): + return content + return "[visit] Failed to read page." + + def readpage_jina(self, url: str, goal: str) -> str: + """ + Attempt to read webpage content by alternating between jina and aidata services. 
+ + Args: + url: The URL to read + goal: The goal/purpose of reading the page + + Returns: + str: The webpage content or error message + """ + + summary_page_func = self.call_server + max_retries = int(os.getenv('VISIT_SERVER_MAX_RETRIES', 1)) + + content = self.html_readpage_jina(url) + + if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"): + content = truncate_to_tokens(content, max_tokens=95000) + messages = [{"role":"user","content": EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)}] + parse_retry_times = 0 + raw = summary_page_func(messages, max_retries=max_retries) + summary_retries = 3 + while len(raw) < 10 and summary_retries >= 0: + truncate_length = int(0.7 * len(content)) if summary_retries > 0 else 25000 + status_msg = ( + f"[visit] Summary url[{url}] " + f"attempt {3 - summary_retries + 1}/3, " + f"content length: {len(content)}, " + f"truncating to {truncate_length} chars" + ) if summary_retries > 0 else ( + f"[visit] Summary url[{url}] failed after 3 attempts, " + f"final truncation to 25000 chars" + ) + print(status_msg) + content = content[:truncate_length] + extraction_prompt = EXTRACTOR_PROMPT.format( + webpage_content=content, + goal=goal + ) + messages = [{"role": "user", "content": extraction_prompt}] + raw = summary_page_func(messages, max_retries=max_retries) + summary_retries -= 1 + + parse_retry_times = 0 + if isinstance(raw, str): + raw = raw.replace("```json", "").replace("```", "").strip() + while parse_retry_times < 3: + try: + raw = json.loads(raw) + break + except: + raw = summary_page_func(messages, max_retries=max_retries) + parse_retry_times += 1 + + if parse_retry_times >= 3: + useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) + useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. 
Please check the URL or file format." + "\n\n" + useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n" + else: + useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) + useful_information += "Evidence in page: \n" + str(raw["evidence"]) + "\n\n" + useful_information += "Summary: \n" + str(raw["summary"]) + "\n\n" + + if len(useful_information) < 10 and summary_retries < 0: + print("[visit] Could not generate valid summary after maximum retries") + useful_information = "[visit] Failed to read page" + + return useful_information + + # If no valid content was obtained after all retries + else: + useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) + useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" + useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." 
+ "\n\n" + return useful_information + + \ No newline at end of file diff --git a/fastapi_app/dependencies/auth.py b/fastapi_app/dependencies/auth.py index fd3e81a..c5283d6 100644 --- a/fastapi_app/dependencies/auth.py +++ b/fastapi_app/dependencies/auth.py @@ -8,10 +8,15 @@ from typing import Any, Optional from fastapi import Header, HTTPException +from workflow_engine.logger import get_logger + +log = get_logger(__name__) try: from supabase import create_client, Client -except Exception: + log.info("Supabase 库导入成功") +except Exception as e: + log.warning(f"Supabase 库导入失败: {e}") create_client = None # type: ignore[misc, assignment] Client = Any # type: ignore[misc, assignment] @@ -124,9 +129,10 @@ async def get_current_user(authorization: Optional[str] = Header(None)) -> AuthU except HTTPException: raise except Exception as e: + log.error(f"Token validation failed: {type(e).__name__}: {str(e)}", exc_info=True) raise HTTPException( status_code=401, - detail=f"Token validation failed: {str(e)}" + detail="Token validation failed" ) diff --git a/fastapi_app/embedding_server.py b/fastapi_app/embedding_server.py new file mode 100644 index 0000000..b38e5fb --- /dev/null +++ b/fastapi_app/embedding_server.py @@ -0,0 +1,162 @@ +""" +本地 Embedding 服务:加载 Octen/Octen-Embedding-0.6B,提供 OpenAI 兼容的 POST /v1/embeddings。 +可单独启动:uvicorn fastapi_app.embedding_server:app --host 127.0.0.1 --port 17997 +或由主后端在 USE_LOCAL_EMBEDDING=1 时自动拉起。 +""" +from __future__ import annotations + +import os +from contextlib import asynccontextmanager +from typing import List, Union + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel, Field + +EMBEDDING_MODEL_NAME = "Octen-Embedding-0.6B" +HF_MODEL_ID = "Octen/Octen-Embedding-0.6B" + + +def _pick_device() -> str: + """通过 nvidia-smi 查询空闲显存最多的 GPU,避免触碰已损坏的 CUDA context。""" + import subprocess + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,memory.free,memory.total", + "--format=csv,noheader,nounits"], + 
capture_output=True, text=True, timeout=10, + ) + if result.returncode != 0: + print(f"[embedding_server] nvidia-smi 失败,回退到 CPU") + return "cpu" + best_idx, best_free = -1, 0 + for line in result.stdout.strip().splitlines(): + parts = [p.strip() for p in line.split(",")] + idx, free, total = int(parts[0]), int(parts[1]), int(parts[2]) + print(f"[embedding_server] GPU {idx}: 空闲 {free} MB / 总共 {total} MB") + if free > best_free: + best_free = free + best_idx = idx + if best_idx >= 0 and best_free > 512: # 至少 512 MB 空闲 + device = f"cuda:{best_idx}" + print(f"[embedding_server] 选择 {device}(空闲 {best_free} MB)") + return device + print("[embedding_server] 所有 GPU 显存不足,回退到 CPU") + return "cpu" + except Exception as e: + print(f"[embedding_server] 查询 GPU 失败: {e},回退到 CPU") + return "cpu" + + +def _get_embedder(): + """懒加载,首次请求时下载并加载模型,自动选择空闲 GPU。""" + if _get_embedder._model is None: + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise RuntimeError( + "请安装 sentence-transformers: pip install sentence-transformers" + ) + device = _pick_device() + _get_embedder._model = SentenceTransformer(HF_MODEL_ID, device=device) + return _get_embedder._model + + +_get_embedder._model = None + + +class EmbeddingRequest(BaseModel): + model: str = Field(default=EMBEDDING_MODEL_NAME, description="模型名,可忽略") + input: Union[str, List[str]] = Field(..., description="单条文本或文本列表") + + +class EmbeddingItem(BaseModel): + object: str = "embedding" + embedding: List[float] + index: int + + +class EmbeddingResponse(BaseModel): + object: str = "list" + data: List[EmbeddingItem] + model: str = EMBEDDING_MODEL_NAME + usage: dict = Field(default_factory=lambda: {"prompt_tokens": 0, "total_tokens": 0}) + + +def _ensure_model_loaded(): + """启动时检查:已缓存则 log 提示,未缓存则下载并加载。""" + try: + from huggingface_hub import snapshot_download + snapshot_download(repo_id=HF_MODEL_ID, local_files_only=True) + print(f"[embedding_server] 模型已缓存,正在加载 {HF_MODEL_ID} ...") + except Exception: + 
print(f"[embedding_server] 模型未缓存,正在下载并加载 {HF_MODEL_ID}(首次较慢)...") + _get_embedder() + print(f"[embedding_server] {EMBEDDING_MODEL_NAME} 已就绪。") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动时检查/下载并加载模型,不再依赖远程 embedding + try: + _ensure_model_loaded() + except Exception as e: + print(f"[embedding_server] 加载失败: {e}") + raise + yield + if _get_embedder._model is not None: + try: + del _get_embedder._model + _get_embedder._model = None + except Exception: + pass + + +app = FastAPI( + title="Local Embedding (Octen-Embedding-0.6B)", + version="0.1.0", + lifespan=lifespan, +) + + +@app.post("/v1/embeddings", response_model=EmbeddingResponse) +async def embeddings(req: EmbeddingRequest): + """OpenAI 兼容的 embedding 接口。""" + if isinstance(req.input, str): + texts = [req.input] + else: + texts = list(req.input) + if not texts: + raise HTTPException(status_code=400, detail="input 不能为空") + + # 限制单次 batch 大小,避免 OOM + max_batch = int(os.getenv("EMBEDDING_MAX_BATCH", "32")) + if len(texts) > max_batch: + raise HTTPException( + status_code=400, + detail=f"单次最多 {max_batch} 条,当前 {len(texts)} 条", + ) + + try: + model = _get_embedder() + # 换行可能影响效果,与 VectorStoreManager 行为一致 + texts_clean = [t.replace("\n", " ").strip() or " " for t in texts] + emb = model.encode( + texts_clean, + normalize_embeddings=True, + show_progress_bar=False, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + if emb.ndim == 1: + emb = emb.reshape(1, -1) + data = [ + EmbeddingItem(embedding=emb[i].tolist(), index=i) + for i in range(len(texts)) + ] + return EmbeddingResponse(data=data) + + +@app.get("/health") +async def health(): + return {"status": "ok", "model": EMBEDDING_MODEL_NAME} diff --git a/fastapi_app/fireredtts_manager.py b/fastapi_app/fireredtts_manager.py new file mode 100644 index 0000000..7ba7ccc --- /dev/null +++ b/fastapi_app/fireredtts_manager.py @@ -0,0 +1,215 @@ +""" +FireRedTTS2 Manager +懒加载 + 自动卸载 +""" +import os +import time +import 
threading +import torch + +REPO_ID = "FireRedTeam/FireRedTTS2" +IDLE_TIMEOUT = int(os.getenv("TTS_IDLE_TIMEOUT", "300")) + +_model = None +_device = None +_model_path = None +_last_used = None +_lock = threading.Lock() +_unload_timer = None + + +def _pick_device(): + """选择最优GPU""" + if not torch.cuda.is_available(): + return "cpu" + + try: + import subprocess + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"], + capture_output=True, text=True, timeout=5 + ) + lines = [l.strip() for l in result.stdout.strip().split("\n") if l.strip()] + if not lines: + return "cuda:0" + + gpu_mem = [] + for line in lines: + parts = line.split(",") + if len(parts) == 2: + idx, mem = parts[0].strip(), parts[1].strip() + gpu_mem.append((int(idx), int(mem))) + + if gpu_mem: + best_gpu = max(gpu_mem, key=lambda x: x[1]) + return f"cuda:{best_gpu[0]}" + except Exception: + pass + + return "cuda:0" + + +def _schedule_unload(): + """调度自动卸载""" + global _unload_timer + if _unload_timer is not None: + _unload_timer.cancel() + + def _unload(): + global _model, _device, _last_used + with _lock: + if _last_used and (time.time() - _last_used >= IDLE_TIMEOUT): + print(f"[TTS] 空闲 {IDLE_TIMEOUT}s,卸载模型") + _model = None + _device = None + _last_used = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + _unload_timer = threading.Timer(IDLE_TIMEOUT, _unload) + _unload_timer.daemon = True + _unload_timer.start() + + +def _load_model(): + """懒加载模型""" + global _model, _device, _model_path, _last_used + + if _model is not None: + _last_used = time.time() + _schedule_unload() + return _model, _device, _model_path + + print(f"[TTS] 加载 FireRedTTS2 模型: {REPO_ID}") + + try: + from fireredtts2.fireredtts2 import FireRedTTS2 + except ImportError as e: + raise RuntimeError(f"fireredtts2 未安装: {e}\n运行: pip install fireredtts2") + + # 下载模型到 HuggingFace 缓存 + try: + from huggingface_hub import snapshot_download + _model_path = snapshot_download( 
+ repo_id=REPO_ID, + resume_download=True, + local_files_only=False + ) + print(f"[TTS] 模型已下载到: {_model_path}") + except Exception as e: + print(f"[TTS] 模型下载失败: {e}") + raise + + _device = _pick_device() + print(f"[TTS] 使用设备: {_device}") + + try: + _model = FireRedTTS2( + pretrained_dir=_model_path, + gen_type="dialogue", + device=_device, + ) + except Exception as e: + print(f"[TTS] 加载失败: {e}") + raise + + _last_used = time.time() + _schedule_unload() + print(f"[TTS] 模型加载完成") + + return _model, _device, _model_path + + +def generate_speech(text: str, voice_name: str = "S1", temperature: float = 0.9) -> bytes: + """ + 生成语音,返回 WAV 音频字节 + + Args: + text: 文本内容,必须包含说话人标签格式 "[S1]text\n[S2]text" + voice_name: 未使用(保留参数兼容性) + temperature: 生成温度 + + Returns: + WAV 格式音频字节 (24kHz, 16-bit, mono) + """ + with _lock: + model, device, model_path = _load_model() + + # 解析文本为对话列表,并分割过长的行 + text_list = [] + max_line_len = 200 # 每行最多200字符 + for line in text.strip().split("\n"): + line = line.strip() + if line and (line.startswith("[S1]") or line.startswith("[S2]")): + speaker = line[:4] # [S1] or [S2] + content = line[4:].strip() + # 分割过长内容 + if len(content) <= max_line_len: + text_list.append(line) + else: + # 按句子分割 + import re + sentences = re.split(r'([。!?.!?])', content) + current = "" + for i in range(0, len(sentences), 2): + sent = sentences[i] + punct = sentences[i+1] if i+1 < len(sentences) else "" + if len(current) + len(sent) + len(punct) <= max_line_len: + current += sent + punct + else: + if current: + text_list.append(f"{speaker}{current}") + current = sent + punct + if current: + text_list.append(f"{speaker}{current}") + + if not text_list: + raise ValueError("No valid dialogue lines found in text") + + print(f"[TTS] 生成 {len(text_list)} 行对话,总长度: {sum(len(t) for t in text_list)}") + + # 生成音频(使用模型默认音色) + import torchaudio + import io + audio = model.generate_dialogue( + text_list=text_list, + temperature=temperature, + topk=30, + ) + + # 保存为 WAV 字节 + buf = io.BytesIO() + 
torchaudio.save(buf, audio, 24000, format="wav") + return buf.getvalue() + + +def is_available() -> bool: + """检查 FireRedTTS2 是否可用""" + try: + from fireredtts2.fireredtts2 import FireRedTTS2 + return True + except ImportError: + return False + + +def check_and_download_model(): + """启动时检查并自动下载 FireRedTTS2 模型""" + # 检查并安装 fireredtts2 + try: + import fireredtts2 + print(f"[TTS] fireredtts2 已安装") + except ImportError: + print(f"[TTS] 正在安装 fireredtts2...") + import subprocess + import sys + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", "fireredtts2"]) + print(f"[TTS] fireredtts2 安装完成") + except Exception as e: + print(f"[TTS] fireredtts2 安装失败: {e}") + return + + print(f"[TTS] 检查模型: {REPO_ID}") + print(f"[TTS] 首次使用时将自动从 HuggingFace 下载或使用本地缓存") + + diff --git a/fastapi_app/main.py b/fastapi_app/main.py index 1222a4d..3bbbb45 100644 --- a/fastapi_app/main.py +++ b/fastapi_app/main.py @@ -12,22 +12,39 @@ from dotenv import load_dotenv _root = Path(__file__).resolve().parent.parent load_dotenv(_root / "fastapi_app" / ".env") - load_dotenv(_root / ".env") except ImportError: pass +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + +# 启动时检查 Supabase 配置 +_supabase_url = os.getenv("SUPABASE_URL") +_supabase_anon = os.getenv("SUPABASE_ANON_KEY") +_supabase_service = os.getenv("SUPABASE_SERVICE_ROLE_KEY") +if _supabase_url and _supabase_anon: + log.info(f"Supabase 已配置: URL={_supabase_url[:30]}..., ANON_KEY={'已设置' if _supabase_anon else '未设置'}, SERVICE_KEY={'已设置' if _supabase_service else '未设置'}") +else: + log.info(f"Supabase 未配置: URL={'已设置' if _supabase_url else '未设置'}, ANON_KEY={'已设置' if _supabase_anon else '未设置'}") + + from urllib.parse import unquote from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from fastapi_app.routers import kb, kb_embedding, files, paper2drawio, paper2ppt +from fastapi_app.routers import kb, kb_embedding, files, 
paper2drawio, paper2ppt, auth, data_insight from fastapi_app.middleware.api_key import APIKeyMiddleware -from dataflow_agent.utils import get_project_root +from fastapi_app.middleware.logging import LoggingMiddleware +from workflow_engine.utils import get_project_root + +# 导入workflow模块以触发所有workflow注册 +from workflow_engine import workflow # 本地 Embedding 服务端口(Octen-Embedding-0.6B) -LOCAL_EMBEDDING_PORT = 17997 +LOCAL_EMBEDDING_PORT = 26210 LOCAL_EMBEDDING_URL = f"http://127.0.0.1:{LOCAL_EMBEDDING_PORT}/v1/embeddings" @@ -37,36 +54,61 @@ async def _lifespan(app: FastAPI): use_local = os.getenv("USE_LOCAL_EMBEDDING", "1").strip().lower() in ("1", "true", "yes") proc = None if use_local: + # 检查 embedding 服务是否已在运行(--reload 场景下避免重复拉起) + _already_running = False try: - proc = subprocess.Popen( - [ - sys.executable, "-m", "uvicorn", - "fastapi_app.embedding_server:app", - "--host", "127.0.0.1", - "--port", str(LOCAL_EMBEDDING_PORT), - ], - cwd=str(Path(__file__).resolve().parent.parent), - stdout=None, - stderr=None, - ) - os.environ["EMBEDDING_API_URL"] = LOCAL_EMBEDDING_URL - os.environ["EMBEDDING_MODEL"] = "Octen-Embedding-0.6B" - for _ in range(60): - time.sleep(0.5) - if proc.poll() is not None: - print("[WARN] 本地 Embedding 子进程已退出,请检查上方日志") - break - try: - import urllib.request - urllib.request.urlopen(f"http://127.0.0.1:{LOCAL_EMBEDDING_PORT}/health", timeout=1) - print(f"[INFO] 本地 Embedding 已就绪 (Octen-Embedding-0.6B) @ {LOCAL_EMBEDDING_URL}") - break - except Exception: - continue - else: - print("[WARN] 本地 Embedding 启动超时,请检查 sentence-transformers 是否已安装及上方日志") + import urllib.request + urllib.request.urlopen(f"http://127.0.0.1:{LOCAL_EMBEDDING_PORT}/health", timeout=2) + _already_running = True + log.info(f"本地 Embedding 已在运行,复用 @ {LOCAL_EMBEDDING_URL}") except Exception as e: - print(f"[WARN] 启动本地 Embedding 失败: {e}") + log.debug(f"本地 Embedding 健康检查失败: {e}") + if not _already_running: + try: + proc = subprocess.Popen( + [ + sys.executable, "-m", "uvicorn", + 
"fastapi_app.embedding_server:app", + "--host", "127.0.0.1", + "--port", str(LOCAL_EMBEDDING_PORT), + ], + cwd=str(Path(__file__).resolve().parent.parent), + stdout=None, + stderr=None, + ) + for _ in range(60): + time.sleep(0.5) + if proc.poll() is not None: + log.warning("本地 Embedding 子进程已退出,请检查上方日志") + break + try: + import urllib.request + urllib.request.urlopen(f"http://127.0.0.1:{LOCAL_EMBEDDING_PORT}/health", timeout=1) + log.info(f"本地 Embedding 已就绪 (Octen-Embedding-0.6B) @ {LOCAL_EMBEDDING_URL}") + break + except Exception: + continue + else: + log.warning("本地 Embedding 启动超时,请检查 sentence-transformers 是否已安装及上方日志") + except Exception as e: + log.warning(f"启动本地 Embedding 失败: {e}") + os.environ["EMBEDDING_API_URL"] = LOCAL_EMBEDDING_URL + os.environ["EMBEDDING_MODEL"] = "Octen-Embedding-0.6B" + + # 检查 TTS 模型 + use_local_tts = os.getenv("USE_LOCAL_TTS", "0").strip().lower() in ("1", "true", "yes") + tts_engine = os.getenv("TTS_ENGINE", "qwen").strip().lower() + if use_local_tts: + try: + if tts_engine == "qwen": + from fastapi_app.qwen_tts_manager import check_and_download_model + check_and_download_model() + elif tts_engine == "firered": + from fastapi_app.fireredtts_manager import check_and_download_model + check_and_download_model() + except Exception as e: + log.warning(f"TTS 模型检查失败: {e}") + yield if proc is not None and proc.poll() is None: proc.terminate() @@ -100,6 +142,9 @@ def create_app() -> FastAPI: allow_headers=["*"], ) + # Logging middleware (first to capture all requests) + app.add_middleware(LoggingMiddleware) + # API key verification for /api/* routes app.add_middleware(APIKeyMiddleware) @@ -109,6 +154,8 @@ def create_app() -> FastAPI: app.include_router(files.router, prefix="/api/v1", tags=["Files"]) app.include_router(paper2drawio.router, prefix="/api/v1", tags=["Paper2Drawio"]) app.include_router(paper2ppt.router, prefix="/api/v1", tags=["Paper2PPT"]) + app.include_router(auth.router, prefix="/api/v1", tags=["Auth"]) + 
app.include_router(data_insight.router, prefix="/api/v1", tags=["Data Insight"]) # 静态文件:/outputs 下的文件(兼容 URL 中 %40 与 磁盘 @ 两种路径) project_root = get_project_root() @@ -131,17 +178,18 @@ async def serve_outputs(path: str): if file_path.suffix.lower() == ".pdf": resp.headers["Content-Disposition"] = "inline" return resp - except Exception: + except Exception as e: + log.debug(f"文件路径解析失败: {candidate}, 错误: {e}") continue raise HTTPException(status_code=404, detail="Not found") - print(f"[INFO] Serving /outputs from {outputs_dir}") + log.info(f"Serving /outputs from {outputs_dir}") @app.get("/health") async def health_check(): return {"status": "ok"} - print("[INFO] 后端已连接 / Backend ready") + log.info("后端已连接 / Backend ready") return app diff --git a/fastapi_app/middleware/api_key.py b/fastapi_app/middleware/api_key.py index a10167d..82b732c 100644 --- a/fastapi_app/middleware/api_key.py +++ b/fastapi_app/middleware/api_key.py @@ -30,6 +30,7 @@ async def workflow(_: None = Depends(verify_api_key)): "/docs", "/openapi.json", "/redoc", + "/api/v1/auth/config", } # Path prefixes that don't require API key @@ -48,8 +49,13 @@ class APIKeyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): path = request.url.path + # Skip OPTIONS requests (CORS preflight) + if request.method == "OPTIONS": + return await call_next(request) + # Skip excluded paths if path in EXCLUDED_PATHS: + print(f"[DEBUG] Path {path} in EXCLUDED_PATHS, skipping API key check") return await call_next(request) # Skip excluded prefixes diff --git a/fastapi_app/middleware/logging.py b/fastapi_app/middleware/logging.py new file mode 100644 index 0000000..f474274 --- /dev/null +++ b/fastapi_app/middleware/logging.py @@ -0,0 +1,42 @@ +import time +import uuid +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from workflow_engine.logger import get_logger, set_request_context + +log = get_logger(__name__) + + +class 
LoggingMiddleware(BaseHTTPMiddleware): + """Middleware for request logging and context tracking.""" + + async def dispatch(self, request: Request, call_next): + request_id = str(uuid.uuid4()) + user_email = None + user_id = None + + try: + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + import base64 + import json + token = auth_header.split(" ", 1)[1] + parts = token.split(".") + if len(parts) >= 2: + payload_b64 = parts[1] + payload_b64 += "=" * (4 - len(payload_b64) % 4) + decoded = json.loads(base64.urlsafe_b64decode(payload_b64)) + user_email = decoded.get("email") + user_id = decoded.get("sub") or decoded.get("user_id") + except Exception: + pass + + set_request_context(request_id=request_id, user_id=user_id, user_email=user_email) + log.info(f"{request.method} {request.url.path}") + + start_time = time.time() + response = await call_next(request) + duration = time.time() - start_time + + log.info(f"{request.method} {request.url.path} - {response.status_code} ({duration:.3f}s)") + return response diff --git a/fastapi_app/notebook_paths.py b/fastapi_app/notebook_paths.py index 327f571..bbdf6af 100644 --- a/fastapi_app/notebook_paths.py +++ b/fastapi_app/notebook_paths.py @@ -3,7 +3,7 @@ All path construction for the new directory layout lives here: - outputs/{safe_title}_{notebook_id}/ + outputs/{user_id}/{safe_title}_{notebook_id}/ ├── sources/{source_stem}/original/ │ /mineru/ │ /markdown/ @@ -22,13 +22,60 @@ from pathlib import Path from typing import Optional -from dataflow_agent.utils import get_project_root +from workflow_engine.utils import get_project_root # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- +def _sanitize_user_id(user_id: Optional[str], max_len: int = 64) -> str: + """ + Turn a user_id (often an email like 'user@example.com') into a + filesystem-safe directory name. 
+ + Handles special characters that may appear in emails: + - @ -> _at_ + - Whitespace -> _ + - Slashes, backslashes -> _ + - Other unsafe characters -> removed + + Examples: + '765973346@qq.com' -> '765973346_at_qq.com' + 'user+tag@example.com' -> 'user_tag_at_example.com' + """ + user_id = (user_id or "").strip() + if not user_id: + return "local" + + # Normalize unicode + user_id = unicodedata.normalize("NFC", user_id) + + # Replace @ with _at_ to preserve email structure readability + user_id = user_id.replace("@", "_at_") + + # Replace slashes and backslashes + user_id = user_id.replace("/", "_").replace("\\", "_") + + # Replace whitespace runs with underscore + user_id = re.sub(r"\s+", "_", user_id) + + # Keep only safe chars: word chars (a-z, A-Z, 0-9, _), hyphen, dot + # Note: We keep dots for email domains like 'qq.com' + user_id = re.sub(r"[^\w\-.]", "_", user_id, flags=re.ASCII) + + # Collapse multiple underscores + user_id = re.sub(r"_+", "_", user_id) + + # Strip leading/trailing special chars + user_id = user_id.strip("_.- ") + + if not user_id: + return "local" + + return user_id[:max_len] + + def _sanitize_dir_name(title: str, max_len: int = 60) -> str: """ Turn an arbitrary notebook title into a filesystem-safe directory component. @@ -54,13 +101,12 @@ def resolve_notebook_title( user_id: Optional[str] = None, ) -> str: """ - Look up the notebook title from local JSON or Supabase. + Look up the notebook title from local JSON. Returns the title string, or empty string if not found. 
""" - # 1) Try local JSON root = get_project_root() - safe_uid = (user_id or "default").replace("/", "_").replace("\\", "_")[:64] - local_path = root / "outputs" / "kb_data" / "_notebooks" / f"{safe_uid}.json" + safe_uid = _sanitize_user_id(user_id) + local_path = root / "outputs" / safe_uid / "_notebooks.json" if local_path.exists(): try: data = json.loads(local_path.read_text(encoding="utf-8")) @@ -70,19 +116,6 @@ def resolve_notebook_title( return nb.get("name") or nb.get("title") or "" except Exception: pass - - # 2) Try Supabase - try: - from fastapi_app.dependencies.auth import get_supabase_admin_client - sb = get_supabase_admin_client() - if sb: - r = sb.table("knowledge_bases").select("name").eq("id", notebook_id).limit(1).execute() - rows = (r.data or []) if hasattr(r, "data") else [] - if rows: - return rows[0].get("name") or "" - except Exception: - pass - return "" @@ -101,6 +134,7 @@ def __init__( project_root: Optional[Path] = None, ): self._notebook_id = notebook_id + self._user_id = user_id or "local" self._project_root = project_root or get_project_root() # Resolve title if not provided @@ -123,30 +157,48 @@ def notebook_dir_name(self) -> str: @property def root(self) -> Path: - """outputs/{title}_{id}/ — with fallback scan for existing dirs.""" + """outputs/{user_id}/{title}_{id}/ — with fallback scan for existing dirs.""" if self._resolved_root is not None: return self._resolved_root - candidate = self._project_root / "outputs" / self.notebook_dir_name + safe_uid = _sanitize_user_id(self._user_id) + candidate = self._project_root / "outputs" / safe_uid / self.notebook_dir_name if candidate.exists(): self._resolved_root = candidate return candidate - # Fallback: scan outputs/ for any dir ending with _{notebook_id} + # Fallback: scan outputs/{user_id}/ or outputs/ for any dir ending with _{notebook_id} self._resolved_root = self._find_existing_root() or candidate return self._resolved_root def _find_existing_root(self) -> Optional[Path]: - """Scan 
outputs/ for a directory whose name ends with _{notebook_id}.""" + """Scan outputs/{user_id}/ and outputs/ for a directory whose name ends with _{notebook_id}.""" safe_id = self._notebook_id.replace("/", "_").replace("\\", "_")[:128] suffix = f"_{safe_id}" outputs_dir = self._project_root / "outputs" if not outputs_dir.exists(): return None + + # First try user-specific directory + safe_uid = _sanitize_user_id(self._user_id) + user_dir = outputs_dir / safe_uid + if user_dir.exists(): + try: + for d in user_dir.iterdir(): + if d.is_dir() and d.name.endswith(suffix): + return d + except Exception: + pass + + # Fallback: scan ALL user directories under outputs/ for the notebook + # This handles cases where user_id changed (e.g., UUID -> email) try: - for d in outputs_dir.iterdir(): - if d.is_dir() and d.name.endswith(suffix): - return d + for user_candidate in outputs_dir.iterdir(): + if not user_candidate.is_dir() or user_candidate == user_dir: + continue + for d in user_candidate.iterdir(): + if d.is_dir() and d.name.endswith(suffix): + return d except Exception: pass return None diff --git a/fastapi_app/qwen_tts_manager.py b/fastapi_app/qwen_tts_manager.py new file mode 100644 index 0000000..d08eaa6 --- /dev/null +++ b/fastapi_app/qwen_tts_manager.py @@ -0,0 +1,164 @@ +""" +Qwen3-TTS Manager +懒加载 + 自动卸载 +""" +import os +import time +import threading +import torch + +REPO_ID = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" # CustomVoice 支持预定义说话人 +IDLE_TIMEOUT = int(os.getenv("TTS_IDLE_TIMEOUT", "300")) + +_model = None +_device = None +_last_used = None +_lock = threading.Lock() +_unload_timer = None + + +def _pick_device(): + """选择最优GPU""" + if not torch.cuda.is_available(): + return "cpu" + + try: + import subprocess + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits"], + capture_output=True, text=True, timeout=5 + ) + lines = [l.strip() for l in result.stdout.strip().split("\n") if l.strip()] + if not lines: + 
return "cuda:0" + + gpu_mem = [] + for line in lines: + parts = line.split(",") + if len(parts) == 2: + idx, mem = parts[0].strip(), parts[1].strip() + gpu_mem.append((int(idx), int(mem))) + + if gpu_mem: + best_gpu = max(gpu_mem, key=lambda x: x[1]) + return f"cuda:{best_gpu[0]}" + except Exception: + pass + + return "cuda:0" + + +def _schedule_unload(): + """调度自动卸载""" + global _unload_timer + if _unload_timer is not None: + _unload_timer.cancel() + + def _unload(): + global _model, _device, _last_used + with _lock: + if _last_used and (time.time() - _last_used >= IDLE_TIMEOUT): + print(f"[TTS] 空闲 {IDLE_TIMEOUT}s,卸载模型") + _model = None + _device = None + _last_used = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + _unload_timer = threading.Timer(IDLE_TIMEOUT, _unload) + _unload_timer.daemon = True + _unload_timer.start() + + +def _load_model(): + """懒加载模型""" + global _model, _device, _last_used + + if _model is not None: + _last_used = time.time() + _schedule_unload() + return _model, _device + + print(f"[TTS] 加载 Qwen3-TTS 模型: {REPO_ID}") + + try: + from qwen_tts import Qwen3TTSModel + except ImportError as e: + raise RuntimeError(f"qwen_tts 未安装: {e}\n运行: pip install qwen-tts") + + _device = _pick_device() + dtype = torch.bfloat16 if _device.startswith("cuda") else torch.float32 + print(f"[TTS] 使用设备: {_device}, dtype: {dtype}") + + try: + _model = Qwen3TTSModel.from_pretrained( + REPO_ID, + device_map=_device, + dtype=dtype, + ) + except Exception as e: + print(f"[TTS] 加载失败: {e}") + raise + + _last_used = time.time() + _schedule_unload() + print(f"[TTS] 模型加载完成") + + return _model, _device + + +def generate_speech(text: str, voice_name: str = "vivian", language: str = "Chinese") -> bytes: + """ + 生成语音,返回 WAV 音频字节 + + Args: + text: 文本内容 + voice_name: 说话人名称(默认 vivian) + 支持: aiden, dylan, eric, ono_anna, ryan, serena, sohee, uncle_fu, vivian + language: 语言(Chinese/English,默认 Chinese) + + Returns: + WAV 格式音频字节 + """ + with _lock: + model, device = 
_load_model() + + # 确保 voice_name 是小写且在支持列表中 + supported_speakers = ['aiden', 'dylan', 'eric', 'ono_anna', 'ryan', 'serena', 'sohee', 'uncle_fu', 'vivian'] + voice_name = voice_name.lower() + if voice_name not in supported_speakers: + voice_name = 'vivian' + + # 播客风格指令 + instruct = "用自然、亲切的播客主播语气讲述,语速适中,富有感染力" if language == "Chinese" else "Speak in a natural, friendly podcast host tone with moderate pace and engaging delivery" + + # 生成音频 + wavs, sr = model.generate_custom_voice( + text=text.strip(), + language=language, + speaker=voice_name, + instruct=instruct, + ) + + # 转换为 WAV 字节 + import io + import soundfile as sf + buf = io.BytesIO() + sf.write(buf, wavs[0], sr, format="wav") + buf.seek(0) + return buf.read() + + +def is_available() -> bool: + """检查 Qwen3-TTS 是否可用""" + try: + from qwen_tts import Qwen3TTSModel + return True + except ImportError: + return False + + +def check_and_download_model(): + """启动时检查并下载 Qwen3-TTS 模型""" + print(f"[TTS] 检查模型: {REPO_ID}") + print(f"[TTS] 首次使用时将自动从 HuggingFace 下载或使用本地缓存") diff --git a/fastapi_app/routers/auth.py b/fastapi_app/routers/auth.py new file mode 100644 index 0000000..15d2003 --- /dev/null +++ b/fastapi_app/routers/auth.py @@ -0,0 +1,30 @@ +""" +Auth configuration endpoint. +""" +import os +from fastapi import APIRouter +from typing import Dict, Any, Optional +from fastapi_app.dependencies.auth import get_supabase_client + +router = APIRouter(prefix="/auth", tags=["Auth"]) + + +@router.get("/config") +async def get_auth_config() -> Dict[str, Any]: + """ + Check if Supabase is configured and return config. + Frontend calls this to determine if auth is available and get credentials. 
+ """ + supabase = get_supabase_client() + configured = supabase is not None + + result: Dict[str, Any] = { + "supabaseConfigured": configured + } + + if configured: + # Return Supabase URL and anon key for frontend to use + result["supabaseUrl"] = os.getenv("SUPABASE_URL") + result["supabaseAnonKey"] = os.getenv("SUPABASE_ANON_KEY") + + return result diff --git a/fastapi_app/routers/data_insight.py b/fastapi_app/routers/data_insight.py new file mode 100644 index 0000000..b7ac8e9 --- /dev/null +++ b/fastapi_app/routers/data_insight.py @@ -0,0 +1,247 @@ +""" +Data Insight Discovery API +Multi-dataset insight analysis using DM framework. +""" +import json +import tempfile +from pathlib import Path +from typing import List, Optional, Dict, Any +from fastapi import APIRouter, Form, HTTPException, UploadFile, File +from fastapi.responses import FileResponse +from pydantic import BaseModel + +import pandas as pd + +from workflow_engine.logger import get_logger +from fastapi_app.services.data_insight_service import DataInsightService + +log = get_logger(__name__) +router = APIRouter(prefix="/data_insight", tags=["data_insight"]) + + +# ==================== Pydantic Models ==================== +class DataInsightResponse(BaseModel): + """Response model for data insight analysis""" + status: str + synthesized_insights: List[str] + raw_insights: List[str] + summary: str + detailed_appendix: Dict[str, Any] = {} + result_path: str = "" + error: Optional[str] = None + + +class ErrorResponse(BaseModel): + """Standard error response""" + error: str + code: str = "INTERNAL_ERROR" + details: Optional[Dict] = None + + +# ==================== API Endpoints ==================== +@router.post( + "/analyze", + response_model=DataInsightResponse, + responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, +) +async def analyze_datasets( + chat_api_url: str = Form(..., description="LLM API URL"), + api_key: str = Form(..., description="LLM API key"), + model: str = 
Form("deepseek-v3.2", description="Model name"), + output_mode: str = Form("concise", description="Output mode: concise or detailed"), + language: str = Form("en", description="Language preference"), + files: List[UploadFile] = File(..., description="Data files (CSV, Excel)"), + analysis_goal: Optional[str] = Form(None, description="Custom analysis goal"), + email: Optional[str] = Form(None, description="User email"), +): + """ + Analyze multiple datasets and discover insights. + + Accepts CSV, Excel files. + Returns synthesized insights and summary. + """ + try: + # Validate inputs + if not files: + raise HTTPException(status_code=400, detail="No files provided") + + if not api_key or not chat_api_url: + raise HTTPException(status_code=400, detail="API key and URL required") + + # Call service + service = DataInsightService() + result = await service.analyze_datasets( + chat_api_url=chat_api_url, + api_key=api_key, + model=model, + output_mode=output_mode, + analysis_goal=analysis_goal, + language=language, + email=email, + files=files, + ) + + # Check for errors + if result.get("status") == "error": + raise HTTPException( + status_code=500, + detail=result.get("error", "Analysis failed") + ) + + # Convert raw_insights from dict to string if needed + raw_insights = result.get("raw_insights", []) + if raw_insights and isinstance(raw_insights[0], dict): + # Convert dict format to string representation + result["raw_insights"] = [ + f"[{item.get('source', 'unknown')}] {item.get('insight', str(item))}" + for item in raw_insights + ] + + return DataInsightResponse(**result) + + except HTTPException: + raise + except Exception as e: + log.error(f"Unexpected error in analyze_datasets: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# ==================== Helper Functions ==================== +def generate_markdown_report( + synthesized_insights: List[str], + raw_insights: List[str], + summary: str, + detailed_appendix: Dict[str, Any], + 
language: str = "en" +) -> str: + """ + Generate a markdown report from analysis results. + + Args: + synthesized_insights: List of synthesized insights + raw_insights: List of raw insights from individual agents + summary: Overall summary + detailed_appendix: Detailed appendix data + language: Language preference (en/zh) + + Returns: + Markdown formatted report content + """ + lang = language.lower() + is_zh = lang == "zh" + + # Headers + title = "📊 Data Insight Report" if is_zh else "📊 Data Insight Report" + summary_header = "📝 Summary" if is_zh else "📝 Summary" + insights_header = "💡 Key Insights" if is_zh else "💡 Key Insights" + raw_header = "📋 Raw Analysis" if is_zh else "📋 Raw Analysis" + appendix_header = "📎 Detailed Appendix" if is_zh else "📎 Detailed Appendix" + footer = f"*Generated on {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}*" if is_zh else f"*Generated on {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}*" + + # Build report + report_lines = [ + f"# {title}", + "", + f"## {summary_header}", + "", + summary, + "", + f"## {insights_header}", + "" + ] + + # Add synthesized insights + for i, insight in enumerate(synthesized_insights, 1): + if is_zh: + report_lines.append(f"### Insight {i}") + else: + report_lines.append(f"### Insight {i}") + report_lines.append("") + report_lines.append(insight) + report_lines.append("") + + # Add raw insights if available + if raw_insights: + report_lines.append(f"## {raw_header}") + report_lines.append("") + for i, insight in enumerate(raw_insights, 1): + report_lines.append(f"**{i}.** {insight}") + report_lines.append("") + + # Add detailed appendix if available + if detailed_appendix: + report_lines.append(f"## {appendix_header}") + report_lines.append("") + for key, value in detailed_appendix.items(): + report_lines.append(f"### {key}") + report_lines.append("") + if isinstance(value, dict): + for k, v in value.items(): + report_lines.append(f"- **{k}:** {v}") + elif isinstance(value, list): + for item in 
value: + report_lines.append(f"- {item}") + else: + report_lines.append(str(value)) + report_lines.append("") + + # Add footer + report_lines.append("---") + report_lines.append("") + report_lines.append(footer) + + return "\n".join(report_lines) + + +# ==================== New API Endpoints ==================== +@router.post( + "/generate_report", + responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, +) +async def generate_report( + synthesized_insights: str = Form(..., description="JSON string of synthesized insights"), + raw_insights: str = Form(..., description="JSON string of raw insights"), + summary: str = Form(..., description="Analysis summary"), + detailed_appendix: str = Form("{}", description="JSON string of detailed appendix"), + language: str = Form("en", description="Language preference"), +): + """ + Generate a markdown report from analysis results. + """ + try: + # Parse JSON strings + synthesized = json.loads(synthesized_insights) if synthesized_insights else [] + raw = json.loads(raw_insights) if raw_insights else [] + appendix = json.loads(detailed_appendix) if detailed_appendix else {} + + # Generate markdown report + report_content = generate_markdown_report( + synthesized_insights=synthesized, + raw_insights=raw, + summary=summary, + detailed_appendix=appendix, + language=language + ) + + # Save to temporary file + temp_dir = Path(tempfile.gettempdir()) / "data_insight_reports" + temp_dir.mkdir(parents=True, exist_ok=True) + + report_filename = f"insight_report_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.md" + report_path = temp_dir / report_filename + report_path.write_text(report_content, encoding='utf-8') + + log.info(f"Generated markdown report: {report_path}") + + return FileResponse( + path=str(report_path), + filename=report_filename, + media_type='text/markdown' + ) + + except json.JSONDecodeError as e: + log.error(f"JSON decode error: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON format in 
request") + except Exception as e: + log.error(f"Error generating report: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/fastapi_app/routers/files.py b/fastapi_app/routers/files.py index 0a05167..f9994d1 100644 --- a/fastapi_app/routers/files.py +++ b/fastapi_app/routers/files.py @@ -12,7 +12,7 @@ from fastapi_app.dependencies import get_current_user, get_optional_user, AuthUser from fastapi_app.utils import _from_outputs_url -from dataflow_agent.utils import get_project_root +from workflow_engine.utils import get_project_root router = APIRouter(prefix="/files", tags=["files"]) diff --git a/fastapi_app/routers/kb.py b/fastapi_app/routers/kb.py index 41c9b64..7be07eb 100644 --- a/fastapi_app/routers/kb.py +++ b/fastapi_app/routers/kb.py @@ -12,14 +12,11 @@ import fitz # PyMuPDF -from dataflow_agent.state import IntelligentQARequest, IntelligentQAState, KBPodcastRequest, KBPodcastState, KBMindMapRequest, KBMindMapState -from dataflow_agent.workflow.wf_intelligent_qa import create_intelligent_qa_graph -from dataflow_agent.workflow.wf_kb_podcast import create_kb_podcast_graph -from dataflow_agent.workflow.wf_kb_mindmap import create_kb_mindmap_graph -from dataflow_agent.toolkits.ragtool.vector_store_tool import process_knowledge_base_files, VectorStoreManager -from dataflow_agent.utils import get_project_root -from dataflow_agent.logger import get_logger -from dataflow_agent.workflow import run_workflow +from workflow_engine.state import IntelligentQARequest, IntelligentQAState, KBPodcastRequest, KBPodcastState, KBMindMapRequest, KBMindMapState +from workflow_engine.toolkits.ragtool.vector_store_tool import process_knowledge_base_files, VectorStoreManager +from workflow_engine.utils import get_project_root +from workflow_engine.logger import get_logger +from workflow_engine.workflow import run_workflow log = get_logger(__name__) from fastapi_app.config import settings @@ -27,11 +24,11 @@ from fastapi_app.utils import 
_from_outputs_url, _to_outputs_url from fastapi_app.workflow_adapters.wa_paper2ppt import _init_state_from_request from fastapi_app.dependencies.auth import get_supabase_admin_client -from fastapi_app.notebook_paths import NotebookPaths, get_notebook_paths +from fastapi_app.notebook_paths import NotebookPaths, get_notebook_paths, _sanitize_user_id from fastapi_app.source_manager import SourceManager from fastapi_app.services.fast_research_service import fast_research_search from fastapi_app.services.deep_research_report_service import generate_report_from_search -from dataflow_agent.toolkits.research_tools import fetch_page_text +from workflow_engine.toolkits.research_tools import fetch_page_text router = APIRouter(prefix="/kb", tags=["Knowledge Base"]) @@ -45,18 +42,20 @@ def _notebook_dir(email: str, notebook_id: Optional[str]) -> Path: - """User + notebook scoped dir under kb_data. Use raw email on disk so StaticFiles can resolve URL-decoded path.""" + """User + notebook scoped dir under kb_data. Email is sanitized for filesystem safety.""" root = get_project_root() - base = root / KB_BASE_DIR / (email or "default") + safe_email = _sanitize_user_id(email) if email else "default" + base = root / KB_BASE_DIR / safe_email if notebook_id: return base / notebook_id.replace("/", "_").replace("\\", "_")[:128] return base / "_shared" def _outputs_dir(email: str, notebook_id: Optional[str], subdir: str) -> Path: - """User + notebook scoped output dir. Use raw email on disk for StaticFiles resolution.""" + """User + notebook scoped output dir. 
Email is sanitized for filesystem safety.""" root = get_project_root() - base = root / OUTPUTS_BASE / (email or "default") + safe_email = _sanitize_user_id(email) if email else "default" + base = root / OUTPUTS_BASE / safe_email if notebook_id: base = base / notebook_id.replace("/", "_").replace("\\", "_")[:128] else: @@ -119,7 +118,7 @@ def _text_to_pdf(text: str, output_path: str) -> None: doc.close() -ALLOWED_EXTENSIONS = {".pdf", ".docx", ".pptx", ".png", ".jpg", ".jpeg", ".mp4", ".md"} +ALLOWED_EXTENSIONS = {".pdf", ".docx", ".pptx", ".png", ".jpg", ".jpeg", ".mp4", ".md", ".csv", ".txt", ".db", ".json", ".jsonl", ".xlsx", ".xls", ".parquet", ".ndjson"} IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"} DOC_EXTENSIONS = {".pdf", ".docx", ".doc", ".pptx", ".ppt", ".md", ".markdown"} @@ -161,7 +160,8 @@ def _find_mineru_stem_dir( # 2) Legacy: kb_mineru/{email}/{notebook_id}/ safe_nb = (notebook_id or "_shared").replace("/", "_").replace("\\", "_")[:128] - mineru_base = project_root / "outputs" / "kb_mineru" / (email or "default") / safe_nb + safe_email = _sanitize_user_id(email) if email else "default" + mineru_base = project_root / "outputs" / "kb_mineru" / safe_email / safe_nb if not mineru_base.exists(): return None @@ -329,14 +329,14 @@ def _append_images_to_pptx(pptx_path: Path, image_paths: List[Path]) -> None: @router.post("/upload") async def upload_kb_file( - file: UploadFile = File(...), + files: List[UploadFile] = File(...), email: str = Form(...), user_id: str = Form(...), notebook_id: Optional[str] = Form(None), notebook_title: Optional[str] = Form(None), ): """ - Upload a file to the notebook's knowledge base directory. + Upload multiple files to the notebook's knowledge base directory. New layout: outputs/{title}_{id}/sources/{stem}/original/ Fallback: also writes to legacy kb_data path for backward compat. 
""" @@ -345,77 +345,89 @@ async def upload_kb_file( if not notebook_id: raise HTTPException(status_code=400, detail="notebook_id is required for per-notebook storage") - file_ext = Path(file.filename).suffix.lower() - if file_ext not in ALLOWED_EXTENSIONS: - raise HTTPException( - status_code=400, - detail=f"Unsupported file type: {file_ext}. Allowed: {', '.join(ALLOWED_EXTENSIONS)}" - ) + uploaded_files = [] + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + mgr = SourceManager(paths) - try: - filename = file.filename or f"unnamed_{user_id}" - filename = os.path.basename(filename) + for file in files: + file_ext = Path(file.filename).suffix.lower() + if file_ext not in ALLOWED_EXTENSIONS: + log.warning(f"Skipping unsupported file: {file.filename}") + continue - # --- New notebook-centric layout --- - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) - mgr = SourceManager(paths) + try: + filename = file.filename or f"unnamed_{user_id}" + filename = os.path.basename(filename) - # Save uploaded bytes to a temp location first, then import - tmp_dir = paths.root / "_tmp" - tmp_dir.mkdir(parents=True, exist_ok=True) - tmp_path = tmp_dir / filename - with open(tmp_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) + # Save uploaded bytes to a temp location first, then import + tmp_dir = paths.root / "_tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + tmp_path = tmp_dir / filename + with open(tmp_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) - source_info = await mgr.import_file(tmp_path, filename) + source_info = await mgr.import_file(tmp_path, filename) - # Clean up temp - try: - tmp_path.unlink(missing_ok=True) - except Exception: - pass + # Clean up temp + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass - # Build static URL from the original path in new layout - project_root = get_project_root() - rel = source_info.original_path.relative_to(project_root) - 
static_path = "/" + rel.as_posix() - - # --- Also write to legacy path for backward compat --- - legacy_dir = _notebook_dir(email, notebook_id) - legacy_dir.mkdir(parents=True, exist_ok=True) - legacy_path = legacy_dir / filename - if not legacy_path.exists(): - shutil.copy2(str(source_info.original_path), str(legacy_path)) + # Build static URL from the original path in new layout + project_root = get_project_root() + rel = source_info.original_path.relative_to(project_root) + static_path = "/" + rel.as_posix() - # Auto-embed using new vector_store path - embedded = False - try: - vector_base = str(paths.vector_store_dir) - mineru_base = str(paths.source_mineru_dir(filename)) - file_list = [{"path": str(source_info.original_path)}] - await process_knowledge_base_files( - file_list=file_list, - base_dir=vector_base, - mineru_output_base=mineru_base, - ) - embedded = True - log.info("[upload] auto-embedding done: %s", filename) - except Exception as emb_err: - log.warning("[upload] auto-embedding failed for %s: %s", filename, emb_err) + # --- Also write to legacy path for backward compat --- + legacy_dir = _notebook_dir(email, notebook_id) + legacy_dir.mkdir(parents=True, exist_ok=True) + legacy_path = legacy_dir / filename + if not legacy_path.exists(): + shutil.copy2(str(source_info.original_path), str(legacy_path)) - return { - "success": True, - "filename": filename, - "file_size": os.path.getsize(source_info.original_path), - "storage_path": str(source_info.original_path), - "static_url": static_path, - "file_type": file.content_type, - "embedded": embedded, - } + # Auto-embed using new vector_store path + embedded = False + try: + vector_base = str(paths.vector_store_dir) + mineru_base = str(paths.source_mineru_dir(filename)) + file_list = [{"path": str(source_info.original_path)}] + # 使用本地embedding服务 + local_embedding_url = os.getenv("EMBEDDING_API_URL", "http://127.0.0.1:26210/v1/embeddings") + await process_knowledge_base_files( + file_list=file_list, + 
base_dir=vector_base, + mineru_output_base=mineru_base, + api_url=local_embedding_url, + ) + embedded = True + log.info("[upload] auto-embedding done: %s", filename) + except Exception as emb_err: + log.warning("[upload] auto-embedding failed for %s: %s", filename, emb_err) + + uploaded_files.append({ + "filename": filename, + "file_size": os.path.getsize(source_info.original_path), + "storage_path": str(source_info.original_path), + "static_url": static_path, + "file_type": file.content_type, + "embedded": embedded, + }) - except Exception as e: - print(f"Error uploading file: {e}") - raise HTTPException(status_code=500, detail=str(e)) + except Exception as e: + log.error(f"Error uploading file {file.filename}: {e}") + uploaded_files.append({ + "filename": file.filename, + "error": str(e), + "success": False + }) + + return { + "success": True, + "files": uploaded_files, + "total_uploaded": len([f for f in uploaded_files if "error" not in f]), + "total_failed": len([f for f in uploaded_files if "error" in f]), + } def _sanitize_md_filename(title: str, prefix: str = "doc") -> str: @@ -465,7 +477,7 @@ async def add_text_source( raise HTTPException(status_code=400, detail="content is required") # New layout - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) mgr = SourceManager(paths) source_info = await mgr.import_text(content, title) @@ -528,7 +540,7 @@ async def import_url_as_source( title = "网页" # New layout - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) mgr = SourceManager(paths) source_info = await mgr.import_url(url, text, title) @@ -589,7 +601,8 @@ def _vector_store_base_dir(email: Optional[str], notebook_id: Optional[str]) -> if not email: base = root / "outputs" / "kb_data" / "vector_store_main" else: - base = root / "outputs" / "kb_data" / (email or 
"default") + safe_email = _sanitize_user_id(email) + base = root / "outputs" / "kb_data" / safe_email if notebook_id: safe_nb = notebook_id.replace("/", "_").replace("\\", "_")[:128] base = base / safe_nb / "vector_store" @@ -612,10 +625,17 @@ async def chat_with_kb( """ Intelligent QA Chat. 若传 email/notebook_id 且该 notebook 已建索引,会优先用 RAG 检索片段作为上下文。 """ + log.info(f"[chat_with_kb] === Request received ===") + log.info(f"[chat_with_kb] files (raw): {files}") + log.info(f"[chat_with_kb] email (raw): {email}") + log.info(f"[chat_with_kb] notebook_id (raw): {notebook_id}") + log.info(f"[chat_with_kb] query length: {len(query)}") + try: # Normalize file paths (web path -> local absolute path) project_root = get_project_root() local_files = [] + for f in files: # remove leading /outputs/ if present, or just join # Web path: /outputs/kb_data/... @@ -623,21 +643,56 @@ async def chat_with_kb( p = project_root / clean_path if p.exists(): local_files.append(str(p)) + log.info(f"[chat_with_kb] ✓ Found file: {f} -> {p}") else: # Try raw path p_raw = Path(f) if p_raw.exists(): local_files.append(str(p_raw)) - + log.info(f"[chat_with_kb] ✓ Found file (raw): {f} -> {p_raw}") + else: + log.warning(f"[chat_with_kb] ✗ File not found: {f}") + + log.info(f"[chat_with_kb] Resolved local_files: {local_files}") + if not local_files: # Just return empty answer or handle logic - pass + log.warning("[chat_with_kb] No valid local files found, will rely on RAG only") - vector_store_base_dir = _vector_store_base_dir(email, notebook_id) + # Use new notebook paths system instead of legacy _vector_store_base_dir + vector_store_base_dir = None + if email and notebook_id: + try: + # Find notebook directory by scanning outputs/{email}/ + project_root = get_project_root() + email_dir = project_root / "outputs" / email.replace("@", "_at_") + + if email_dir.exists(): + # Look for directories matching pattern *_{notebook_id} + for nb_dir in email_dir.iterdir(): + if nb_dir.is_dir() and 
nb_dir.name.endswith(f"_{notebook_id}"): + vector_store_path = nb_dir / "vector_store" + if vector_store_path.exists(): + vector_store_base_dir = str(vector_store_path) + log.info(f"[chat_with_kb] Found vector store: {vector_store_base_dir}") + break + + if not vector_store_base_dir: + log.warning(f"[chat_with_kb] No vector_store found for email={email}, notebook_id={notebook_id}") + except Exception as e: + log.warning(f"[chat_with_kb] Failed to search for vector store: {e}") + + # Fallback to legacy system if needed + if not vector_store_base_dir: + vector_store_base_dir = _vector_store_base_dir(email, notebook_id) + if vector_store_base_dir: + log.info(f"[chat_with_kb] Using legacy paths system, vector_store_base_dir: {vector_store_base_dir}") + else: + log.warning(f"[chat_with_kb] vector_store_base_dir not found in either new or legacy system") # Construct Request req = IntelligentQARequest( - files=local_files, + file_ids=local_files, query=query, history=history, vector_store_base_dir=vector_store_base_dir, @@ -800,8 +855,9 @@ async def list_outputs( email: Optional[str] = None, user_id: Optional[str] = None, notebook_id: Optional[str] = None, + notebook_title: Optional[str] = None, ) -> Dict[str, Any]: - """List generated outputs (ppt/mindmap/podcast) for user. Prefer DB, fallback to disk scan.""" + """List generated outputs (ppt/mindmap/podcast/drawio) for user. 
Prefer DB, fallback to disk scan.""" sb = get_supabase_admin_client() project_root = get_project_root() files: List[Dict[str, Any]] = [] @@ -829,34 +885,38 @@ async def list_outputs( }) except Exception as e: log.warning("list_outputs from db failed: %s", e) - if not files and email: - for email_part in (email.replace("@", "%40"), email): - base = project_root / OUTPUTS_BASE / email_part - if not base.exists(): - continue - # If notebook_id given, only scan that notebook's subdir - if notebook_id: - base = base / notebook_id.replace("/", "_").replace("\\", "_")[:128] - if not base.exists(): - break - it = [base] if base.is_dir() and notebook_id else sorted(base.iterdir(), key=lambda x: x.stat().st_mtime if x.is_dir() else 0, reverse=True)[:50] - for d in it: - if not d.is_dir(): - continue - ts = d.name - for f in d.rglob("*"): - if f.suffix.lower() in (".pdf", ".pptx", ".mmd", ".mermaid", ".wav", ".mp3", ".m4a"): - rel = str(f.relative_to(project_root)) - out_type = "podcast" if f.suffix.lower() in (".wav", ".mp3", ".m4a") else ("mindmap" if f.suffix.lower() in (".mmd", ".mermaid") else "ppt") - files.append({ - "id": f"disk_{ts}_{f.name}", - "output_type": out_type, - "file_name": f.name, - "download_url": _to_outputs_url(rel), - "created_at": d.stat().st_mtime, - }) - break - break + # Disk fallback: scan notebook-centric directory layout + if not files and notebook_id: + _FEATURE_EXT_MAP = { + "ppt": {".pdf", ".pptx"}, + "mindmap": {".mmd", ".mermaid"}, + "podcast": {".wav", ".mp3", ".m4a"}, + "drawio": {".drawio"}, + } + try: + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + nb_root = paths.root + if nb_root.exists(): + for feature, exts in _FEATURE_EXT_MAP.items(): + feature_dir = nb_root / feature + if not feature_dir.exists(): + continue + for ts_dir in feature_dir.iterdir(): + if not ts_dir.is_dir(): + continue + for f in ts_dir.iterdir(): + if f.suffix.lower() in exts: + rel = str(f.relative_to(project_root)) + 
files.append({ + "id": f"disk_{ts_dir.name}_{f.name}", + "output_type": feature, + "file_name": f.name, + "download_url": _to_outputs_url(rel), + "created_at": ts_dir.stat().st_mtime, + }) + break + except Exception as e: + log.warning("list_outputs disk scan failed: %s", e) return {"success": True, "files": files} @@ -940,10 +1000,10 @@ def _save_output_record( # ---------- 1.3 笔记本(目录)与后端联动 ---------- def _notebooks_local_path(user_id: str) -> Path: root = get_project_root() - base = root / "outputs" / "kb_data" / "_notebooks" + safe_id = _sanitize_user_id(user_id) + base = root / "outputs" / safe_id base.mkdir(parents=True, exist_ok=True) - safe_id = (user_id or "default").replace("/", "_").replace("\\", "_")[:64] - return base / f"{safe_id}.json" + return base / "_notebooks.json" def _list_notebooks_local(user_id: str) -> List[Dict[str, Any]]: @@ -980,9 +1040,9 @@ def _create_notebook_local(user_id: str, name: str, description: str = "") -> Di return new_nb -# 不做用户管理时使用的默认用户,数据从 outputs 取 -DEFAULT_USER_ID = "default" -DEFAULT_EMAIL = "default" +# 不做用户管理时使用的默认用户,数据从 outputs/local 取 +DEFAULT_USER_ID = "local" +DEFAULT_EMAIL = "local" @router.get("/notebooks") @@ -990,31 +1050,10 @@ async def list_notebooks( email: Optional[str] = None, user_id: Optional[str] = None, ) -> Dict[str, Any]: - """List notebooks. 
No user_id => use default (data from outputs).""" - uid = (user_id or "").strip() or DEFAULT_USER_ID - sb = get_supabase_admin_client() - if sb: - try: - q = sb.table("knowledge_bases").select("id,name,description,created_at,updated_at").eq("user_id", uid) - r = q.order("updated_at", desc=True).execute() - rows = (r.data or []) if hasattr(r, "data") else [] - if rows: - from collections import Counter - nb_ids = [row["id"] for row in rows] - try: - fr = sb.table("knowledge_base_files").select("kb_id").in_("kb_id", nb_ids).execute() - file_rows = (fr.data or []) if hasattr(fr, "data") else [] - counts = Counter(f.get("kb_id") for f in file_rows if f.get("kb_id")) - except Exception as e: - log.warning("notebooks file count failed: %s", e) - counts = {} - for row in rows: - row["sources"] = counts.get(row["id"], 0) - return {"success": True, "notebooks": rows} - except Exception as e: - log.warning("list_notebooks failed: %s", e) - return {"success": True, "notebooks": []} - rows = _list_notebooks_local(uid) + """List notebooks from local filesystem.""" + # Prefer email for directory naming (more readable than UUID) + dir_id = (email or "").strip() or (user_id or "").strip() or DEFAULT_USER_ID + rows = _list_notebooks_local(dir_id) email_for_path = (email or "").strip() or DEFAULT_EMAIL for row in rows: nb_id = row.get("id") @@ -1022,7 +1061,7 @@ async def list_notebooks( count = 0 # New layout count try: - paths = get_notebook_paths(nb_id, row.get("name", ""), uid) + paths = get_notebook_paths(nb_id, row.get("name", ""), dir_id) if paths.sources_dir.exists(): count += sum(1 for d in paths.sources_dir.iterdir() if d.is_dir() and (d / "original").exists()) except Exception: @@ -1061,7 +1100,7 @@ async def list_notebook_files( # --- 1) Read from new layout: outputs/{title}_{id}/sources/ --- try: - paths = get_notebook_paths(notebook_id, notebook_title or "", uid) + paths = get_notebook_paths(notebook_id, notebook_title or "", em or uid) sources_dir = paths.sources_dir 
if sources_dir.exists(): for src_dir in sorted(sources_dir.iterdir()): @@ -1377,7 +1416,7 @@ async def import_link_sources( raise HTTPException(status_code=400, detail="notebook_id and email are required") # New layout - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) mgr = SourceManager(paths) # Legacy compat @@ -1457,32 +1496,24 @@ async def create_notebook( name: str = Body(..., embed=True), description: Optional[str] = Body(None, embed=True), user_id: str = Body(..., embed=True), + email: Optional[str] = Body(None, embed=True), ) -> Dict[str, Any]: - """Create a notebook. Uses Supabase if configured, else local JSON file. - Also creates the new outputs/{title}_{id}/sources/ directory.""" - sb = get_supabase_admin_client() - nb_data = None - if sb: - try: - ins = sb.table("knowledge_bases").insert({"user_id": user_id, "name": name, "description": description or ""}).execute() - data = (ins.data or []) if hasattr(ins, "data") else [] - nb_data = data[0] if data else None - except Exception as e: - log.warning("create_notebook failed: %s", e) - return {"success": False, "message": str(e)} - else: - try: - nb_data = _create_notebook_local(user_id, name, description or "") - except HTTPException: - raise - except Exception as e: - log.warning("create_notebook local failed: %s", e) - return {"success": False, "message": str(e)} + """Create a notebook using local JSON file. 
+ Also creates the new outputs/{user_id}/{title}_{id}/sources/ directory.""" + # Prefer email for directory naming (more readable than UUID) + dir_id = (email or "").strip() or user_id + try: + nb_data = _create_notebook_local(dir_id, name, description or "") + except HTTPException: + raise + except Exception as e: + log.warning("create_notebook local failed: %s", e) + return {"success": False, "message": str(e)} # Create new directory structure if nb_data and nb_data.get("id"): try: - paths = get_notebook_paths(nb_data["id"], name, user_id) + paths = get_notebook_paths(nb_data["id"], name, dir_id) paths.sources_dir.mkdir(parents=True, exist_ok=True) log.info("[create_notebook] created dir: %s", paths.root) except Exception as e: @@ -1562,7 +1593,7 @@ async def generate_ppt_from_kb( project_root = get_project_root() # New layout: outputs/{title}_{id}/ppt/{ts}/ if notebook_id: - nb_paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + nb_paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = nb_paths.feature_output_dir("ppt", ts) else: output_dir = _outputs_dir(email, notebook_id, f"{ts}_ppt") @@ -1686,7 +1717,8 @@ async def generate_ppt_from_kb( embed_api_url = embed_api_url.rstrip("/") + "/embeddings" project_root = get_project_root() safe_nb = (notebook_id or "_shared").replace("/", "_").replace("\\", "_")[:128] - mineru_output_base = project_root / "outputs" / "kb_mineru" / (email or "default") / safe_nb + safe_email = _sanitize_user_id(email) if email else "default" + mineru_output_base = project_root / "outputs" / "kb_mineru" / safe_email / safe_nb mineru_output_base.mkdir(parents=True, exist_ok=True) files_for_embed = [{"path": str(p), "description": ""} for p in doc_paths] @@ -1847,62 +1879,115 @@ async def generate_deep_research_report( search_api_key: Optional[str] = Body(None, embed=True), search_engine: Optional[str] = Body("google", embed=True), search_top_k: int = Body(10, embed=True), + # 
新增:DeepResearch 完整模式配置 + use_full_deep_research: bool = Body(True, embed=True), # 默认使用完整的阿里DeepResearch + max_iterations: int = Body(50, embed=True), # DeepResearch最大迭代次数 + serper_api_key: Optional[str] = Body(None, embed=True), # Serper API密钥 + jina_api_key: Optional[str] = Body(None, embed=True), # Jina API密钥 ) -> Dict[str, Any]: """ - Deep Research 报告生成:search → 结果拼成上下文 → 调 LLM 生成长报告 → 保存为 .md 并可选引入。 - 不生成 PDF,.md 可预览、可嵌入。 + Deep Research 报告生成(默认使用完整版阿里DeepResearch): + - use_full_deep_research=True: 完整版(阿里DeepResearch多轮ReAct推理,深度)【默认】 + - use_full_deep_research=False: 简化版(搜索 + LLM总结,快速) """ try: if not isinstance(page_count, int) or page_count < 1 or page_count > 50: raise HTTPException(status_code=400, detail="page_count must be an integer between 1 and 50") ts = int(time.time()) project_root = get_project_root() + # New layout: outputs/{title}_{id}/deep_research/{ts}/ if notebook_id: - dr_paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + dr_paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = dr_paths.feature_output_dir("deep_research", ts) else: output_dir = _outputs_dir(email, notebook_id, f"{ts}_deep_research") output_dir.mkdir(parents=True, exist_ok=True) topic = topic.strip() - search_top_k = max(1, min(20, search_top_k)) - log.info( - "[generate-deep-research-report] start: topic=%r, search_top_k=%s, provider=%s, model=%s, language=%s", - topic[:150], search_top_k, search_provider, model, language, - ) - # 1) 搜索:用 topic 做 Fast Research,拿到 top_k 条结果 - sources = fast_research_search( - topic, - top_k=search_top_k, - search_provider=search_provider or "serper", - search_api_key=search_api_key, - search_engine=search_engine or "google", - ) - log.info("[generate-deep-research-report] search 完成: 共 %s 条来源", len(sources)) - search_context = "" - if sources: - search_context = "\n\n".join( - f"[{i+1}] 标题: {s.get('title', '')}\n链接: {s.get('link', '')}\n摘要: {s.get('snippet', '')}" - for i, s in 
enumerate(sources) + # ============================================================================ + # 模式选择:完整DeepResearch vs 简化版 + # ============================================================================ + + if use_full_deep_research: + # 使用完整的阿里DeepResearch(多轮ReAct推理) + log.info("[generate-deep-research-report] 使用完整DeepResearch模式: topic=%r, max_iterations=%s", topic[:150], max_iterations) + + # 如果没有传递 serper_api_key,尝试使用 search_api_key 作为回退 + final_serper_key = serper_api_key or search_api_key + + log.info("[generate-deep-research-report] API配置: serper_api_key=%s, search_api_key=%s, final_serper_key=%s", + "***" if serper_api_key else "None", + "***" if search_api_key else "None", + "***" if final_serper_key else "None") + + # 运行完整DeepResearch(直接传递参数,不依赖环境变量) + from fastapi_app.services.deep_research_integration import DeepResearchIntegration + + integration = DeepResearchIntegration( + model_name=model, + api_base=api_url, + api_key=api_key, + max_iterations=max_iterations, + serper_key=final_serper_key, + jina_keys=jina_api_key, ) - log.info("[generate-deep-research-report] search_context 拼接完成: len=%s", len(search_context)) + result = await integration.run_research( + query=topic, + max_iterations=max_iterations + ) + + if not result["success"]: + raise HTTPException(status_code=500, detail=result.get("error", "DeepResearch failed")) + + # 格式化为Markdown + report = integration.format_result_as_markdown(result) + report_title = f"DeepResearch: {topic[:50]}" + + log.info("[generate-deep-research-report] 完整DeepResearch完成: iterations=%s, sources=%s", + result.get("iterations", 0), len(result.get("sources", []))) + else: - log.warning("[generate-deep-research-report] no search results, LLM will generate from topic only") + # 使用简化版(搜索 + LLM总结) + search_top_k = max(1, min(20, search_top_k)) + log.info( + "[generate-deep-research-report] 使用简化版模式: topic=%r, search_top_k=%s, provider=%s, model=%s", + topic[:150], search_top_k, search_provider, model, + ) - # 2) 
LLM:根据 topic + search_context 生成一篇长报告(返回标题 + 正文) - report_title, report = generate_report_from_search( - topic=topic, - search_context=search_context, - api_url=api_url, - api_key=api_key, - model=model, - language=language, - ) - if not (report or "").strip(): - raise HTTPException(status_code=500, detail="LLM did not return report content") - log.info("[generate-deep-research-report] LLM 报告生成完成: title=%r, report_len=%s", report_title, len(report)) + # 1) 搜索:用 topic 做 Fast Research,拿到 top_k 条结果 + sources = fast_research_search( + topic, + top_k=search_top_k, + search_provider=search_provider or "serper", + search_api_key=search_api_key or serper_api_key, + search_engine=search_engine or "google", + ) + log.info("[generate-deep-research-report] search 完成: 共 %s 条来源", len(sources)) + + search_context = "" + if sources: + search_context = "\n\n".join( + f"[{i+1}] 标题: {s.get('title', '')}\n链接: {s.get('link', '')}\n摘要: {s.get('snippet', '')}" + for i, s in enumerate(sources) + ) + log.info("[generate-deep-research-report] search_context 拼接完成: len=%s", len(search_context)) + else: + log.warning("[generate-deep-research-report] no search results, LLM will generate from topic only") + + # 2) LLM:根据 topic + search_context 生成一篇长报告(返回标题 + 正文) + report_title, report = generate_report_from_search( + topic=topic, + search_context=search_context, + api_url=api_url, + api_key=api_key, + model=model, + language=language, + ) + if not (report or "").strip(): + raise HTTPException(status_code=500, detail="LLM did not return report content") + log.info("[generate-deep-research-report] 简化版报告生成完成: title=%r, report_len=%s", report_title, len(report)) # 3) 来源名:固定前缀 [report] + LLM 给的标题,保存为 .md safe_title = re.sub(r'[/\\:*?"<>|]', "", (report_title or "").strip()) or "report" @@ -2006,7 +2091,7 @@ async def generate_podcast_from_kb( ts = int(time.time()) # New layout: outputs/{title}_{id}/podcast/{ts}/ if notebook_id: - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) 
+ paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = paths.feature_output_dir("podcast", ts) else: output_dir = _outputs_dir(email, notebook_id, f"{ts}_podcast") @@ -2094,9 +2179,13 @@ async def generate_podcast_from_kb( else: local_file_paths = [str(filtered_paths[0])] + # Get vector store base directory + vector_store_base_dir = _vector_store_base_dir(email, notebook_id) + # Prepare request podcast_req = KBPodcastRequest( - files=local_file_paths, + file_ids=local_file_paths, + vector_store_base_dir=vector_store_base_dir, chat_api_url=api_url, api_key=api_key, model=model, @@ -2195,7 +2284,7 @@ async def generate_mindmap_from_kb( ts = int(time.time()) # New layout: outputs/{title}_{id}/mindmap/{ts}/ if notebook_id: - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = paths.feature_output_dir("mindmap", ts) else: output_dir = _outputs_dir(email, notebook_id, f"{ts}_mindmap_input") @@ -2232,9 +2321,13 @@ async def generate_mindmap_from_kb( if not local_file_paths: raise HTTPException(status_code=400, detail="No valid files provided") + # Get vector store base directory + vector_store_base_dir = _vector_store_base_dir(email, notebook_id) + # Prepare request mindmap_req = KBMindMapRequest( - files=local_file_paths, + file_ids=local_file_paths, + vector_store_base_dir=vector_store_base_dir, chat_api_url=api_url, api_key=api_key, model=model, @@ -2352,182 +2445,16 @@ async def generate_drawio_from_kb( ): """ 从知识库选中文件生成 DrawIO 图表。 - 优先从 MinerU 提取的 figure 图片走 SAM3 分割生成 drawio(缓存到 sources/{stem}/sam3/), - 没有 figure 图片时 fallback 到文本模式 LLM 生成。 - """ - try: - log.info("[generate-drawio] 收到 file_paths: %s", file_paths) - project_root = get_project_root() - - # --- SAM3 图片模式:有 notebook 且能找到 figure 图片时自动走 SAM3 --- - if notebook_id: - from fastapi_app.services.paper2drawio_service import Paper2DrawioService - - paths = 
get_notebook_paths(notebook_id, notebook_title or "", email) - mgr = SourceManager(paths) - ts = int(time.time()) - output_dir = paths.feature_output_dir("drawio", ts) - output_dir.mkdir(parents=True, exist_ok=True) - - # 收集所有 source 中的 figure 图片 - figure_images = _collect_figure_images(mgr, file_paths, project_root) - if not figure_images: - log.warning("[generate-drawio] SAM3 模式但未找到 figure 图片,回退到文本模式") - else: - service = Paper2DrawioService() - all_xmls = [] - for stem, img_path in figure_images: - sam3_cache = str(mgr.ensure_sam3_dir(stem)) - result = await service.generate_diagram_from_image( - image_path=str(img_path), - chat_api_url=api_url, - api_key=api_key, - model=model, - language=language, - email=email, - sam3_cache_dir=sam3_cache, - output_dir=str(output_dir), - ) - if result.get("success") and result.get("xml_content"): - all_xmls.append(result["xml_content"]) - log.info("[generate-drawio] SAM3 成功: stem=%s", stem) - else: - log.warning("[generate-drawio] SAM3 失败: stem=%s err=%s", stem, result.get("error")) - - if all_xmls: - xml_content = all_xmls[0] # 目前取第一个成功的 - drawio_path = output_dir / f"diagram_{ts}.drawio" - drawio_path.write_text(xml_content, encoding="utf-8") - download_url = _to_outputs_url(str(drawio_path)) - _save_output_record( - email=email, user_id=user_id, notebook_id=notebook_id, - output_type="drawio", file_name=drawio_path.name, - file_path=str(drawio_path), result_path=str(output_dir), - download_url=download_url, - ) - return { - "success": True, - "xml_content": xml_content, - "file_path": download_url, - "error": None, - "output_file_id": f"kb_drawio_{ts}", - } - # SAM3 全部失败,fall through 到文本模式 - log.warning("[generate-drawio] SAM3 全部失败,回退到文本模式") - - # --- 文本模式(原有逻辑) --- - url_sources = [] - local_file_paths = [] - for f in (file_paths or []): - ps = (f or "").strip() - if ps.startswith("http://") or ps.startswith("https://"): - url_sources.append(ps) - else: - local_path = _resolve_local_path(ps) - if not local_path.exists() 
or not local_path.is_file(): - log.warning("[generate-drawio] 文件不存在: 原始=%s 解析后=%s", ps, local_path) - raise HTTPException(status_code=404, detail=f"File not found: {ps}") - local_file_paths.append(str(local_path)) - - parts = [] - for i, url in enumerate(url_sources): - # 优先用引入时已存的 .md,不重新爬 - local_md = _resolve_link_to_local_md(email, notebook_id, url) - if local_md is not None: - try: - content = local_md.read_text(encoding="utf-8", errors="replace") - if content.strip(): - parts.append(f"来源{i + 1}:\n{content}") - log.info("[generate-drawio] 使用已存 .md: %s", local_md.name) - continue - except Exception as e: - log.warning("[generate-drawio] 读取已存 .md 失败 %s: %s", local_md, e) - try: - content = fetch_page_text(url, max_chars=100000) - if content and not content.startswith("["): - parts.append(f"来源{i + 1}:\n{content}") - else: - parts.append(f"来源{i + 1}:\n[抓取失败或无正文]") - except Exception as e: - log.warning("[generate-drawio] 抓取 URL 失败 %s: %s", url[:60], e) - parts.append(f"来源{i + 1}:\n[抓取失败: {e}]") - if local_file_paths: - local_text = _extract_text_from_files(local_file_paths) - if local_text.strip(): - parts.append(local_text) - text_content = "\n\n".join(parts) if parts else "" - if not text_content.strip(): - raise HTTPException( - status_code=400, - detail="No text from selected sources (URL fetch failed or files empty). 
Check link or choose local files.", - ) - from fastapi_app.services.paper2drawio_service import Paper2DrawioService - - service = Paper2DrawioService() - result = await service.generate_diagram( - request=None, - chat_api_url=api_url, - api_key=api_key, - model=model, - enable_vlm_validation=False, - vlm_model=getattr(settings, "PAPER2DRAWIO_VLM_MODEL", "deepseek-v3.2"), - vlm_validation_max_retries=3, - input_type="TEXT", - diagram_type=diagram_type, - diagram_style=diagram_style, - language=language, - email=email, - file=None, - text_content=text_content, - ) - - if not result.get("success") or not result.get("xml_content"): - return { - "success": False, - "xml_content": "", - "file_path": "", - "error": result.get("error") or "Failed to generate diagram", - "output_file_id": None, - } - - xml_content = result["xml_content"] - ts = int(time.time()) - # New layout: outputs/{title}_{id}/drawio/{ts}/ - if notebook_id: - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) - output_dir = paths.feature_output_dir("drawio", ts) - else: - output_dir = project_root / OUTPUTS_BASE / (email or "default") / "_shared" / "drawio" - output_dir.mkdir(parents=True, exist_ok=True) - drawio_path = output_dir / f"diagram_{ts}.drawio" - drawio_path.write_text(xml_content, encoding="utf-8") - download_url = _to_outputs_url(str(drawio_path)) + 注意:此功能正在重构中,暂时不可用。 + 优先:思维导图生成(/generate-mindmap)、播客生成(/generate-podcast) + """ + raise HTTPException( + status_code=501, + detail="DrawIO 生成功能正在重构中,暂时不可用。请使用思维导图生成功能(/api/v1/kb/generate-mindmap)作为替代。" + ) - _save_output_record( - email=email, - user_id=user_id, - notebook_id=notebook_id, - output_type="drawio", - file_name=drawio_path.name, - file_path=str(drawio_path), - result_path=str(output_dir), - download_url=download_url, - ) - return { - "success": True, - "xml_content": xml_content, - "file_path": download_url, - "error": None, - "output_file_id": f"kb_drawio_{ts}", - } - except HTTPException: - raise - except 
Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) @router.post("/save-mindmap") @@ -2624,7 +2551,7 @@ async def generate_flashcards( ts = int(time.time()) flashcard_set_id = f"flashcard_{ts}" if notebook_id: - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = paths.feature_output_dir("flashcard", ts) else: output_dir = _outputs_dir(email, notebook_id, flashcard_set_id) @@ -2711,7 +2638,7 @@ async def generate_quiz( ts = int(time.time()) quiz_id = f"quiz_{ts}" if notebook_id: - paths = get_notebook_paths(notebook_id, notebook_title or "", user_id) + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) output_dir = paths.feature_output_dir("quiz", ts) else: output_dir = _outputs_dir(email, notebook_id, quiz_id) @@ -2742,3 +2669,330 @@ async def generate_quiz( except Exception as e: log.exception("[generate-quiz] failed") raise HTTPException(status_code=500, detail=str(e)) + + +# ===================== Flashcard / Quiz 读取端点 ===================== + +@router.get("/list-flashcard-sets") +async def list_flashcard_sets( + notebook_id: str, + notebook_title: Optional[str] = None, + user_id: Optional[str] = None, + email: Optional[str] = None, +): + """列出某 notebook 下所有已保存的闪卡集合(按时间倒序)""" + try: + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + flashcard_root = paths.root / "flashcard" + sets = [] + if flashcard_root.exists(): + for ts_dir in flashcard_root.iterdir(): + if not ts_dir.is_dir(): + continue + json_file = ts_dir / "flashcards.json" + if not json_file.exists(): + continue + try: + data = json.loads(json_file.read_text(encoding="utf-8")) + sets.append({ + "set_id": ts_dir.name, + "id": data.get("id", ""), + "created_at": data.get("created_at", ""), + "total_count": data.get("total_count", 0), + "source_files": 
data.get("source_files", []), + }) + except Exception: + continue + sets.sort(key=lambda x: x["set_id"], reverse=True) + return {"success": True, "sets": sets} + except Exception as e: + log.exception("[list-flashcard-sets] failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/list-quiz-sets") +async def list_quiz_sets( + notebook_id: str, + notebook_title: Optional[str] = None, + user_id: Optional[str] = None, + email: Optional[str] = None, +): + """列出某 notebook 下所有已保存的测验集合(按时间倒序)""" + try: + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + quiz_root = paths.root / "quiz" + sets = [] + if quiz_root.exists(): + for ts_dir in quiz_root.iterdir(): + if not ts_dir.is_dir(): + continue + json_file = ts_dir / "quiz.json" + if not json_file.exists(): + continue + try: + data = json.loads(json_file.read_text(encoding="utf-8")) + sets.append({ + "set_id": ts_dir.name, + "id": data.get("id", ""), + "created_at": data.get("created_at", ""), + "total_count": data.get("total_count", 0), + "source_files": data.get("source_files", []), + }) + except Exception: + continue + sets.sort(key=lambda x: x["set_id"], reverse=True) + return {"success": True, "sets": sets} + except Exception as e: + log.exception("[list-quiz-sets] failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/get-flashcard-set") +async def get_flashcard_set( + notebook_id: str, + set_id: str, + notebook_title: Optional[str] = None, + user_id: Optional[str] = None, + email: Optional[str] = None, +): + """读取指定闪卡集合的完整数据""" + try: + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + json_file = paths.root / "flashcard" / set_id / "flashcards.json" + if not json_file.exists(): + raise HTTPException(status_code=404, detail="Flashcard set not found") + data = json.loads(json_file.read_text(encoding="utf-8")) + return {"success": True, **data} + except HTTPException: + raise + except Exception as e: + 
log.exception("[get-flashcard-set] failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/get-quiz-set") +async def get_quiz_set( + notebook_id: str, + set_id: str, + notebook_title: Optional[str] = None, + user_id: Optional[str] = None, + email: Optional[str] = None, +): + """读取指定测验集合的完整数据""" + try: + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + json_file = paths.root / "quiz" / set_id / "quiz.json" + if not json_file.exists(): + raise HTTPException(status_code=404, detail="Quiz set not found") + data = json.loads(json_file.read_text(encoding="utf-8")) + return {"success": True, **data} + except HTTPException: + raise + except Exception as e: + log.exception("[get-quiz-set] failed") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================ +# DeepResearch Integration +# ============================================================================ + +@router.post("/deep-research") +async def run_deep_research( + query: str = Body(..., embed=True), + notebook_id: str = Body(..., embed=True), + notebook_title: Optional[str] = Body(None, embed=True), + user_id: Optional[str] = Body(None, embed=True), + email: Optional[str] = Body(None, embed=True), + max_iterations: int = Body(50, embed=True), +): + """ + 运行 DeepResearch 深度研究并将结果保存为 source + + Args: + query: 研究问题 + notebook_id: Notebook ID + notebook_title: Notebook 标题 + user_id: 用户 ID + email: 用户邮箱 + max_iterations: 最大迭代次数 + + Returns: + { + "success": bool, + "query": str, + "answer": str, + "source_info": {...}, # 保存的 source 信息 + "error": str (optional) + } + """ + try: + from fastapi_app.services.deep_research_integration import DeepResearchIntegration + + log.info(f"[deep-research] 开始深度研究: {query}") + + # 1. 
运行完整的 DeepResearch + integration = DeepResearchIntegration() + result = await integration.run_research( + query=query, + max_iterations=max_iterations + ) + + if not result["success"]: + return result + + # 2. 将结果保存为 source + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + mgr = SourceManager(paths) + + # 格式化为 Markdown + markdown_content = integration.format_result_as_markdown(result) + + # 保存为文本 source + source_info = await mgr.import_text( + text=markdown_content, + title=f"DeepResearch: {query[:50]}" + ) + + log.info(f"[deep-research] 已保存结果: {source_info.original_path}") + + # 3. 自动 embedding + try: + vector_base = str(paths.vector_store_dir) + file_list = [{"path": str(source_info.original_path)}] + await process_knowledge_base_files( + file_list=file_list, + vector_base=vector_base, + email=email or "default", + user_id=user_id or "default", + notebook_id=notebook_id, + ) + log.info(f"[deep-research] 已完成 embedding") + except Exception as e: + log.warning(f"[deep-research] Embedding 失败: {e}") + + return { + "success": True, + "query": query, + "answer": result["answer"], + "source_info": { + "file_type": source_info.file_type, + "original_path": str(source_info.original_path), + "markdown_path": str(source_info.markdown_path) if source_info.markdown_path else None, + }, + "sources_count": len(result.get("sources", [])), + } + + except Exception as e: + log.exception("[deep-research] 执行失败") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================ +# Search & Add Integration +# ============================================================================ + +@router.post("/search-and-add") +async def search_and_add( + query: str = Body(..., embed=True), + notebook_id: str = Body(..., embed=True), + notebook_title: Optional[str] = Body(None, embed=True), + user_id: Optional[str] = Body(None, embed=True), + email: Optional[str] = Body(None, embed=True), 
+ top_k: int = Body(10, embed=True), + search_provider: str = Body("serper", embed=True), + search_api_key: Optional[str] = Body(None, embed=True), +): + """ + 搜索并爬取 Top K 结果,保存为 source + + Args: + query: 搜索查询 + notebook_id: Notebook ID + notebook_title: Notebook 标题 + user_id: 用户 ID + email: 用户邮箱 + top_k: 返回前 K 个结果 + search_provider: 搜索引擎提供商 + search_api_key: 搜索 API 密钥 + + Returns: + { + "success": bool, + "query": str, + "sources_count": int, + "crawled_count": int, + "source_info": {...} + } + """ + try: + from fastapi_app.services.search_and_add_service import SearchAndAddService + + log.info(f"[search-and-add] 开始搜索: {query}, top_k={top_k}") + + # 1. 搜索并爬取 + service = SearchAndAddService() + result = await service.search_and_crawl( + query=query, + top_k=top_k, + search_provider=search_provider, + search_api_key=search_api_key, + ) + + if not result["success"]: + return result + + sources = result["sources"] + if not sources: + return { + "success": False, + "query": query, + "error": "未找到搜索结果" + } + + # 2. 将所有结果合并为一个 Markdown 文档 + paths = get_notebook_paths(notebook_id, notebook_title or "", email or user_id) + mgr = SourceManager(paths) + + markdown_content = service.format_sources_as_markdown(sources) + + # 保存为文本 source + source_info = await mgr.import_text( + text=markdown_content, + title=f"Search: {query[:50]}" + ) + + log.info(f"[search-and-add] 已保存 {len(sources)} 个结果: {source_info.original_path}") + + # 3. 
自动 embedding + try: + vector_base = str(paths.vector_store_dir) + file_list = [{"path": str(source_info.original_path)}] + await process_knowledge_base_files( + file_list=file_list, + vector_base=vector_base, + email=email or "default", + user_id=user_id or "default", + notebook_id=notebook_id, + ) + log.info(f"[search-and-add] 已完成 embedding") + except Exception as e: + log.warning(f"[search-and-add] Embedding 失败: {e}") + + crawled_count = sum(1 for s in sources if s["crawl_success"]) + + return { + "success": True, + "query": query, + "sources_count": len(sources), + "crawled_count": crawled_count, + "source_info": { + "file_type": source_info.file_type, + "original_path": str(source_info.original_path), + "markdown_path": str(source_info.markdown_path) if source_info.markdown_path else None, + } + } + + except Exception as e: + log.exception("[search-and-add] 执行失败") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/fastapi_app/routers/kb_embedding.py b/fastapi_app/routers/kb_embedding.py index 7c38c20..00f0aa7 100644 --- a/fastapi_app/routers/kb_embedding.py +++ b/fastapi_app/routers/kb_embedding.py @@ -1,13 +1,14 @@ from fastapi import APIRouter, HTTPException, Body from typing import List, Dict, Optional, Any -import os from pathlib import Path -from dataflow_agent.toolkits.ragtool.vector_store_tool import process_knowledge_base_files, VectorStoreManager -from dataflow_agent.utils import get_project_root -from fastapi_app.config import settings +from workflow_engine.toolkits.ragtool.vector_store_tool import process_knowledge_base_files, VectorStoreManager +from workflow_engine.utils import get_project_root from fastapi_app.utils import _to_outputs_url from fastapi_app.dependencies.auth import get_supabase_client -from fastapi_app.notebook_paths import get_notebook_paths +from fastapi_app.notebook_paths import get_notebook_paths, _sanitize_user_id +from workflow_engine.logger import get_logger + +log = get_logger(__name__) router = 
APIRouter(prefix="/kb", tags=["Knowledge Base Embedding"]) @@ -17,7 +18,8 @@ def _vector_store_dir(email: Optional[str], notebook_id: Optional[str]): project_root = get_project_root() if not email: return project_root / "outputs" / "kb_data" / "vector_store_main" - base = project_root / "outputs" / "kb_data" / (email or "default") + safe_email = _sanitize_user_id(email) + base = project_root / "outputs" / "kb_data" / safe_email if notebook_id: safe_nb = notebook_id.replace("/", "_").replace("\\", "_")[:128] return base / safe_nb / "vector_store" @@ -71,7 +73,7 @@ def _write_manifest_ids_to_supabase(manifest: Dict[str, Any]) -> None: if any(x in err_msg for x in ("kb_file_id", "pgrst204", "schema", "column", "could not find")): pass else: - print(f"[kb_embedding] Supabase writeback failed: {e}") + log.warning(f"Supabase writeback failed: {e}") @router.post("/embedding") async def create_embedding( @@ -82,7 +84,7 @@ async def create_embedding( api_url: Optional[str] = Body(None, embed=True), api_key: Optional[str] = Body(None, embed=True), model_name: Optional[str] = Body(None, embed=True), - multimodal_model: Optional[str] = Body(settings.KB_EMBEDDING_MODEL, embed=True), + multimodal_model: Optional[str] = Body(None, embed=True), image_model: Optional[str] = Body(None, embed=True), video_model: Optional[str] = Body(None, embed=True), ): @@ -117,7 +119,7 @@ async def create_embedding( except Exception: pass else: - print(f"Warning: File not found locally: {local_path}") + log.warning(f"File not found locally: {local_path}") if not process_list: return {"success": False, "message": "No valid files found to process."} @@ -130,7 +132,8 @@ async def create_embedding( else: vector_store_dir = _vector_store_dir(user_email, notebook_id) safe_nb = (notebook_id or "_shared").replace("/", "_").replace("\\", "_")[:128] - mineru_output_base = project_root / "outputs" / "kb_mineru" / (user_email or "default") / safe_nb + safe_email = _sanitize_user_id(user_email) if user_email 
else "default" + mineru_output_base = project_root / "outputs" / "kb_mineru" / safe_email / safe_nb vector_store_dir.mkdir(parents=True, exist_ok=True) mineru_output_base.mkdir(parents=True, exist_ok=True) @@ -162,7 +165,7 @@ async def create_embedding( try: _write_manifest_ids_to_supabase(manifest) except Exception as e: - print(f"[kb_embedding] writeback error: {e}") + log.warning(f"Supabase writeback error: {e}") raise HTTPException( status_code=422, detail=f"向量入库失败: {first_err}" @@ -171,17 +174,18 @@ async def create_embedding( try: _write_manifest_ids_to_supabase(manifest) except Exception as e: - print(f"[kb_embedding] writeback error: {e}") + log.warning(f"Supabase writeback error: {e}") return { "success": True, "message": f"Successfully processed {len(process_list)} files", "manifest": manifest } + except HTTPException: + raise except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) + log.error(f"向量入库失败: {type(e).__name__}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="向量入库失败,请检查文件格式或联系管理员") @router.get("/list") async def list_kb_files( @@ -242,9 +246,8 @@ async def delete_vector( manager.remove_file(file_id) return {"success": True, "message": "向量已删除"} except Exception as e: - import traceback - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(e)) + log.error(f"删除向量失败: {type(e).__name__}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail="删除向量失败") @router.post("/search") diff --git a/fastapi_app/routers/paper2drawio.py b/fastapi_app/routers/paper2drawio.py index d012def..3e11aca 100644 --- a/fastapi_app/routers/paper2drawio.py +++ b/fastapi_app/routers/paper2drawio.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, File, Form, Request, UploadFile from pydantic import BaseModel -from dataflow_agent.logger import get_logger +from workflow_engine.logger import get_logger log = get_logger(__name__) router = 
APIRouter(prefix="/paper2drawio", tags=["paper2drawio"]) diff --git a/fastapi_app/routers/paper2ppt.py b/fastapi_app/routers/paper2ppt.py index 9f5854f..17d6865 100644 --- a/fastapi_app/routers/paper2ppt.py +++ b/fastapi_app/routers/paper2ppt.py @@ -11,7 +11,7 @@ from fastapi_app.schemas import ErrorResponse, FullPipelineRequest, OutlineRefineRequest, PageContentRequest, PPTGenerationRequest from fastapi_app.services.paper2ppt_service import Paper2PPTService -from dataflow_agent.utils.version_manager import ImageVersionManager +from workflow_engine.utils.version_manager import ImageVersionManager from fastapi_app.utils import _to_outputs_url router = APIRouter(tags=["paper2ppt"]) diff --git a/fastapi_app/schemas.py b/fastapi_app/schemas.py index 0094474..009a1ad 100644 --- a/fastapi_app/schemas.py +++ b/fastapi_app/schemas.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Literal -from dataflow_agent.utils import get_project_root +from workflow_engine.utils import get_project_root from pydantic import BaseModel, Field from fastapi_app.config import settings diff --git a/fastapi_app/services/data_insight_service.py b/fastapi_app/services/data_insight_service.py new file mode 100644 index 0000000..930f45b --- /dev/null +++ b/fastapi_app/services/data_insight_service.py @@ -0,0 +1,82 @@ +""" +Data Insight Service +Handles file upload and calls adapter. 
+""" +import time +from pathlib import Path +from typing import Any, Dict, List, Optional +from fastapi import UploadFile + +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root +from fastapi_app.workflow_adapters.wa_data_insight import DataInsightAdapter + +log = get_logger(__name__) + + +class DataInsightService: + """Data insight analysis service""" + + def _create_upload_dir(self, email: Optional[str]) -> Path: + """Create directory for uploaded files.""" + ts = int(time.time()) + root = get_project_root() + upload_dir = root / "outputs" / "data_insights" / (email or "default") / f"{ts}_upload" + upload_dir.mkdir(parents=True, exist_ok=True) + return upload_dir + + async def analyze_datasets( + self, + chat_api_url: str, + api_key: str, + model: str, + output_mode: str, + analysis_goal: Optional[str], + language: str, + email: Optional[str], + files: List[UploadFile], + ) -> Dict[str, Any]: + """ + Execute insight analysis workflow. + + Args: + chat_api_url: LLM API URL + api_key: LLM API key + model: Model name + output_mode: "concise" or "detailed" + analysis_goal: Optional custom goal + language: Language preference + email: User email + files: Uploaded data files + + Returns: + Analysis results dict + """ + # Save uploaded files + upload_dir = self._create_upload_dir(email) + file_paths = [] + + for file in files: + file_path = upload_dir / (file.filename or f"file_{len(file_paths)}.csv") + content = await file.read() + file_path.write_bytes(content) + file_paths.append(str(file_path)) + log.info(f"Uploaded: {file.filename}") + + # Build request dict + request_data = { + "file_ids": file_paths, + "model": model, + "api_key": api_key, + "chat_api_url": chat_api_url, + "output_mode": output_mode, + "analysis_goal": analysis_goal, + "language": language, + "email": email + } + + # Call adapter (NOT workflow directly) + adapter = DataInsightAdapter() + result = await adapter.execute(request_data) + + return result 
diff --git a/fastapi_app/services/deep_research_integration.py b/fastapi_app/services/deep_research_integration.py new file mode 100644 index 0000000..a3a5185 --- /dev/null +++ b/fastapi_app/services/deep_research_integration.py @@ -0,0 +1,359 @@ +""" +完整集成阿里 DeepResearch 到 Open-NotebookLM +使用内部 deep_research 模块 +""" +import os +import json +import asyncio +from typing import List, Dict, Any, Optional +from pathlib import Path + +from workflow_engine.logger import get_logger +from fastapi_app.deep_research.react_agent import MultiTurnReactAgent +from qwen_agent.llm.schema import Message + +log = get_logger(__name__) + + +class DeepResearchIntegration: + """完整集成阿里 DeepResearch""" + + def __init__( + self, + model_name: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + max_iterations: Optional[int] = None, + serper_key: Optional[str] = None, + jina_keys: Optional[str] = None, + dashscope_key: Optional[str] = None, + sandbox_endpoints: Optional[str] = None, + ): + # 配置参数(优先使用传入参数,其次使用环境变量) + self.model_name = model_name or os.getenv("DEEP_RESEARCH_MODEL", "qwen-plus") + self.api_base = api_base or os.getenv("DEEP_RESEARCH_API_BASE", "http://127.0.0.1:6001") + self.api_key = api_key or os.getenv("DEEP_RESEARCH_API_KEY", "EMPTY") + self.max_iterations = max_iterations or int(os.getenv("DEEP_RESEARCH_MAX_ITERATIONS", "50")) + + # 工具配置(优先使用传入参数,其次使用环境变量) + self.serper_key = serper_key or os.getenv("SERPER_KEY_ID", os.getenv("SERPER_API_KEY", "")) + self.jina_keys = jina_keys or os.getenv("JINA_API_KEYS", "") + self.dashscope_key = dashscope_key or os.getenv("DASHSCOPE_API_KEY", "") + self.sandbox_endpoints = sandbox_endpoints or os.getenv("SANDBOX_FUSION_ENDPOINT", "") + + # 调试日志 + log.info(f"[DeepResearchIntegration] 初始化配置:") + log.info(f" - model_name: {self.model_name}") + log.info(f" - api_base: {self.api_base}") + log.info(f" - serper_key: {'***' if self.serper_key else 'None'} (length: {len(self.serper_key) if self.serper_key 
else 0})") + log.info(f" - jina_keys: {'***' if self.jina_keys else 'None'}") + log.info(f" - max_iterations: {self.max_iterations}") + + async def run_research( + self, + query: str, + max_iterations: Optional[int] = None, + temperature: float = 0.85, + presence_penalty: float = 1.1, + ) -> Dict[str, Any]: + """ + 运行完整的 DeepResearch 推理 + + Args: + query: 研究问题 + max_iterations: 最大迭代次数 + temperature: 采样温度 + presence_penalty: 存在惩罚 + + Returns: + { + "success": bool, + "query": str, + "answer": str, + "messages": List[Dict], + "sources": List[Dict], + "termination": str, + "iterations": int + } + """ + log.info(f"[DeepResearch] 开始研究: {query}") + + try: + # 检查必要的配置 + if not self.serper_key: + raise ValueError("SERPER_KEY_ID 或 SERPER_API_KEY 未配置") + + # ⚠️ 重要:在创建 Agent 之前设置环境变量 + # 因为工具在模块加载时读取环境变量 + import os + os.environ["SERPER_KEY_ID"] = self.serper_key + if self.jina_keys: + os.environ["JINA_API_KEYS"] = self.jina_keys + if self.dashscope_key: + os.environ["DASHSCOPE_API_KEY"] = self.dashscope_key + + # ⚠️ Visit 工具的 call_server 需要这三个环境变量来调用 LLM 总结网页内容 + os.environ["API_KEY"] = self.api_key + os.environ["API_BASE"] = self.api_base + os.environ["SUMMARY_MODEL_NAME"] = self.model_name + + # 配置 LLM + llm_config = { + "model": self.model_name, + "api_base": self.api_base, + "api_key": self.api_key, + "generate_cfg": { + "temperature": temperature, + "top_p": 0.95, + "presence_penalty": presence_penalty, + "max_tokens": 10000, + } + } + + # 创建 Agent + agent = MultiTurnReactAgent(llm=llm_config) + + # 设置最大迭代次数 + max_iter = max_iterations or self.max_iterations + + # 运行推理(在线程池中运行,避免阻塞) + result = await asyncio.to_thread( + self._run_agent_sync, + agent, + query, + max_iter + ) + + log.info(f"[DeepResearch] 完成研究,迭代次数: {result['iterations']}") + + return result + + except ImportError as e: + log.error(f"[DeepResearch] 导入失败: {e}") + return { + "success": False, + "query": query, + "answer": "", + "messages": [], + "sources": [], + "error": f"DeepResearch 模块导入失败: {str(e)}", + 
"termination": "import_error" + } + except Exception as e: + log.error(f"[DeepResearch] 执行失败: {e}") + return { + "success": False, + "query": query, + "answer": "", + "messages": [], + "sources": [], + "error": str(e), + "termination": "error" + } + + def _run_agent_sync(self, agent, query, max_iterations): + """同步运行 Agent(在线程池中调用)""" + try: + # 构造 data 参数,符合原始 _run 方法的要求 + # 传递完整的 API base URL 而不是端口号 + data = { + "item": { + "question": query, + "answer": "" # 我们不知道答案,留空 + }, + "planning_port": self.api_base # 传递完整的 API base URL + } + + log.info(f"[DeepResearch] 调用 Agent,API base: {self.api_base}, model: {self.model_name}") + + # 调用 Agent 的 _run 方法 + result = agent._run( + data=data, + model=self.model_name + ) + + # 解析结果 + messages = result.get("messages", []) + answer = result.get("prediction", "") + termination = result.get("termination", "unknown") + + # 提取来源 + sources = self._extract_sources_from_messages(messages) + + return { + "success": True, + "query": query, + "answer": answer, + "messages": messages, + "sources": sources, + "termination": termination, + "iterations": len([m for m in messages if m.get("role") == "assistant"]) + } + + except Exception as e: + log.error(f"[DeepResearch] Agent 运行失败: {e}") + import traceback + traceback.print_exc() + raise + + def _extract_answer(self, messages: List) -> str: + """从消息列表中提取最终答案""" + for msg in reversed(messages): + content = str(msg.content) if hasattr(msg, 'content') else str(msg) + + # 查找 标签 + if "" in content and "" in content: + import re + match = re.search(r'(.*?)', content, re.DOTALL) + if match: + return match.group(1).strip() + + # 如果没有 answer 标签,返回最后一条 assistant 消息 + if hasattr(msg, 'role') and msg.role == "assistant": + # 移除 think 标签 + import re + content = re.sub(r'.*?', '', content, flags=re.DOTALL) + content = re.sub(r'.*?', '', content, flags=re.DOTALL) + return content.strip() + + return "未生成答案" + + def _extract_sources(self, messages: List) -> List[Dict]: + """从消息中提取引用的来源(兼容 Message 对象)""" + 
sources = [] + seen_urls = set() + + for msg in messages: + content = str(msg.content) if hasattr(msg, 'content') else str(msg) + + # 提取 tool_response 中的 URL + if "" in content: + import re + # 提取所有 URL + urls = re.findall(r'https?://[^\s<>"\']+', content) + for url in urls: + if url not in seen_urls: + seen_urls.add(url) + sources.append({ + "url": url, + "type": "web_search" + }) + + return sources + + def _extract_sources_from_messages(self, messages: List[Dict]) -> List[Dict]: + """从消息字典列表中提取引用的来源""" + sources = [] + seen_urls = set() + + for msg in messages: + content = msg.get("content", "") + + # 提取 tool_response 中的 URL + if "" in content: + import re + # 提取所有 URL + urls = re.findall(r'https?://[^\s<>"\']+', content) + for url in urls: + if url not in seen_urls: + seen_urls.add(url) + sources.append({ + "url": url, + "type": "web_search" + }) + + return sources + + def _determine_termination(self, messages: List, max_iterations: int) -> str: + """判断终止原因""" + if not messages: + return "no_messages" + + last_msg = messages[-1] + content = str(last_msg.content) if hasattr(last_msg, 'content') else str(last_msg) + + if "" in content and "" in content: + return "answer" + + assistant_count = len([m for m in messages if hasattr(m, 'role') and m.role == "assistant"]) + if assistant_count >= max_iterations: + return "max_iterations" + + return "unknown" + + def _message_to_dict(self, msg) -> Dict: + """将 Message 对象转换为字典""" + if hasattr(msg, 'role') and hasattr(msg, 'content'): + return { + "role": msg.role, + "content": msg.content + } + else: + return { + "role": "unknown", + "content": str(msg) + } + + def format_result_as_markdown(self, result: Dict[str, Any]) -> str: + """将研究结果格式化为 Markdown""" + md_lines = [ + f"# Deep Research: {result['query']}", + "", + "## Research Answer", + "", + result.get("answer", "No answer generated."), + "", + ] + + # 添加来源 + sources = result.get("sources", []) + if sources: + md_lines.extend([ + "## Sources", + "", + ]) + for i, 
source in enumerate(sources, 1): + url = source.get("url", "") + md_lines.append(f"{i}. [{url}]({url})") + md_lines.append("") + + # 添加元数据 + md_lines.extend([ + "---", + "", + "**Metadata:**", + f"- Termination: {result.get('termination', 'unknown')}", + f"- Iterations: {result.get('iterations', 0)}", + f"- Total messages: {len(result.get('messages', []))}", + "", + ]) + + return "\n".join(md_lines) + + async def check_dependencies(self) -> Dict[str, bool]: + """检查依赖是否满足""" + checks = { + "serper_key": bool(self.serper_key), + "jina_keys": bool(self.jina_keys), + "qwen_agent": False, + } + + try: + import qwen_agent + checks["qwen_agent"] = True + except ImportError: + pass + + return checks + + def get_config_info(self) -> Dict[str, Any]: + """获取配置信息""" + return { + "model": self.model_name, + "api_base": self.api_base, + "max_iterations": self.max_iterations, + "serper_configured": bool(self.serper_key), + "jina_configured": bool(self.jina_keys), + "dashscope_configured": bool(self.dashscope_key), + "sandbox_configured": bool(self.sandbox_endpoints), + } diff --git a/fastapi_app/services/deep_research_report_service.py b/fastapi_app/services/deep_research_report_service.py index 8ea98d9..37790a8 100644 --- a/fastapi_app/services/deep_research_report_service.py +++ b/fastapi_app/services/deep_research_report_service.py @@ -9,7 +9,7 @@ import httpx -from dataflow_agent.logger import get_logger +from workflow_engine.logger import get_logger log = get_logger(__name__) diff --git a/fastapi_app/services/fast_research_service.py b/fastapi_app/services/fast_research_service.py index 0e262d1..506b364 100644 --- a/fastapi_app/services/fast_research_service.py +++ b/fastapi_app/services/fast_research_service.py @@ -12,9 +12,9 @@ import httpx -from dataflow_agent.logger import get_logger +from workflow_engine.logger import get_logger -from dataflow_agent.toolkits.research_tools import ( +from workflow_engine.toolkits.research_tools import ( serpapi_search, google_cse_search, 
brave_search, diff --git a/fastapi_app/services/flashcard_service.py b/fastapi_app/services/flashcard_service.py index 21914f3..6cdd19d 100644 --- a/fastapi_app/services/flashcard_service.py +++ b/fastapi_app/services/flashcard_service.py @@ -9,7 +9,7 @@ from typing import List, Dict, Any from pathlib import Path -from dataflow_agent.logger import get_logger +from workflow_engine.logger import get_logger from fastapi_app.schemas import Flashcard log = get_logger(__name__) @@ -122,6 +122,46 @@ def _build_flashcard_prompt(text_content: str, language: str, card_count: int) - return prompt +def _try_parse_json_array(json_str: str): + """尝试解析 JSON 数组,失败时逐步回退到最后一个完整对象""" + try: + return json.loads(json_str) + except json.JSONDecodeError: + pass + + brace_depth = 0 + in_string = False + escape = False + candidates = [] + for i, ch in enumerate(json_str): + if escape: + escape = False + continue + if ch == '\\' and in_string: + escape = True + continue + if ch == '"' and not escape: + in_string = not in_string + continue + if in_string: + continue + if ch == '{': + brace_depth += 1 + elif ch == '}': + brace_depth -= 1 + if brace_depth == 0: + candidates.append(i) + + for pos in reversed(candidates): + attempt = json_str[:pos + 1] + ']' + try: + return json.loads(attempt) + except json.JSONDecodeError: + continue + + raise json.JSONDecodeError("No valid JSON array found", json_str, 0) + + def _parse_flashcards_from_llm_response(content: str, card_count: int) -> List[Flashcard]: """ 解析 LLM 返回的闪卡数据 @@ -135,11 +175,16 @@ def _parse_flashcards_from_llm_response(content: str, card_count: int) -> List[F """ try: # 提取 JSON(处理可能的 markdown 代码块) - json_match = re.search(r'\[.*\]', content, re.DOTALL) + json_match = re.search(r'```(?:json)?\s*(\[[\s\S]*)', content) if json_match: - flashcards_data = json.loads(json_match.group()) + json_str = json_match.group(1) + json_str = re.sub(r'\s*```\s*$', '', json_str) else: - flashcards_data = json.loads(content) + # fallback: 找 [ 开头的内容 + idx 
= content.find('[') + json_str = content[idx:] if idx >= 0 else content.strip() + + flashcards_data = _try_parse_json_array(json_str) # 转换为 Flashcard 对象 flashcards = [] diff --git a/fastapi_app/services/paper2drawio_service.py b/fastapi_app/services/paper2drawio_service.py index 3ea7471..a05dd76 100644 --- a/fastapi_app/services/paper2drawio_service.py +++ b/fastapi_app/services/paper2drawio_service.py @@ -15,10 +15,10 @@ from fastapi import Request, UploadFile -from dataflow_agent.state import Paper2DrawioState, Paper2DrawioRequest -from dataflow_agent.toolkits.drawio_tools import wrap_xml, extract_cells -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root +from workflow_engine.state import Paper2DrawioState, Paper2DrawioRequest +from workflow_engine.toolkits.drawio_tools import wrap_xml, extract_cells +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root log = get_logger(__name__) @@ -92,7 +92,7 @@ async def generate_diagram( result_path=str(run_dir), ) - from dataflow_agent.workflow.registry import RuntimeRegistry + from workflow_engine.workflow.registry import RuntimeRegistry try: async with task_semaphore: @@ -162,7 +162,7 @@ async def generate_diagram_from_image( state.temp_data = state.temp_data or {} state.temp_data["sam3_cache_dir"] = sam3_cache_dir - from dataflow_agent.workflow.registry import RuntimeRegistry + from workflow_engine.workflow.registry import RuntimeRegistry try: async with task_semaphore: @@ -256,7 +256,7 @@ async def chat_edit( text_content=message, ) - from dataflow_agent.workflow.registry import RuntimeRegistry + from workflow_engine.workflow.registry import RuntimeRegistry try: async with task_semaphore: diff --git a/fastapi_app/services/paper2ppt_service.py b/fastapi_app/services/paper2ppt_service.py index 4dc2131..942d2c2 100644 --- a/fastapi_app/services/paper2ppt_service.py +++ b/fastapi_app/services/paper2ppt_service.py @@ -22,8 +22,8 @@ 
run_paper2ppt_full_pipeline, run_paper2ppt_wf_api, ) -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root log = get_logger(__name__) diff --git a/fastapi_app/services/quiz_service.py b/fastapi_app/services/quiz_service.py index 6c1539c..180e429 100644 --- a/fastapi_app/services/quiz_service.py +++ b/fastapi_app/services/quiz_service.py @@ -9,7 +9,7 @@ from typing import List, Dict, Any from pathlib import Path -from dataflow_agent.logger import get_logger +from workflow_engine.logger import get_logger from fastapi_app.schemas import QuizQuestion, QuizOption log = get_logger(__name__) @@ -100,43 +100,30 @@ def _build_quiz_prompt(text_content: str, language: str, question_count: int) -> {text_content} 出题要求: -1. 题目类型:单选题,每题必须有且仅有 4 个选项(A、B、C、D) -2. 题目质量: - - 考察对文档内容的理解和应用,而非简单记忆 - - 题目表述清晰、准确、无歧义 - - 选项设计合理,干扰项要有一定迷惑性 - - 正确答案必须明确且有据可依 -3. 难度分布: - - 简单题(理解):30% - - 中等题(应用):50% - - 困难题(分析):20% -4. 答案解释: - - 必须给出详细的答案解释 - - 解释要引用文档中的具体内容 - - 说明为什么其他选项是错误的 - -请以 JSON 格式返回,格式如下: +1. 题目类型:单选题,每题 4 个选项(A、B、C、D) +2. 考察理解和应用,题目清晰无歧义,干扰项有迷惑性 +3. 难度分布:简单 30%、中等 50%、困难 20% +4. 
explanation 字段:1-2 句话简要说明正确答案的理由,不要逐个分析错误选项 + +请严格按以下 JSON 格式返回(不要添加额外字段): ```json [ {{ "id": "q1", "question": "题目内容", "options": [ - {{"label": "A", "text": "选项A内容"}}, - {{"label": "B", "text": "选项B内容"}}, - {{"label": "C", "text": "选项C内容"}}, - {{"label": "D", "text": "选项D内容"}} + {{"label": "A", "text": "选项A"}}, + {{"label": "B", "text": "选项B"}}, + {{"label": "C", "text": "选项C"}}, + {{"label": "D", "text": "选项D"}} ], "correct_answer": "A", - "explanation": "详细的答案解释,说明为什么A是正确的,以及为什么其他选项是错误的。", - "source_excerpt": "文档中相关的原文摘录", - "difficulty": "medium", - "category": "application" + "explanation": "简短解释(1-2句话)" }} ] ``` -请确保返回的是有效的 JSON 格式。""" +请确保返回完整、有效的 JSON。""" else: prompt = f"""Based on the following document content, generate {question_count} high-quality multiple-choice quiz questions. @@ -145,84 +132,110 @@ def _build_quiz_prompt(text_content: str, language: str, question_count: int) -> Requirements: 1. Question Type: Multiple choice, each question must have exactly 4 options (A, B, C, D) -2. Quality Standards: - - Test understanding and application, not just memorization - - Questions should be clear, precise, and unambiguous - - Options should be well-designed with plausible distractors - - Correct answer must be definitive and evidence-based -3. Difficulty Distribution: - - Easy (comprehension): 30% - - Medium (application): 50% - - Hard (analysis): 20% -4. Answer Explanation: - - Provide detailed explanation for the correct answer - - Reference specific content from the document - - Explain why other options are incorrect - -Return in JSON format: +2. Test understanding and application, not just memorization. Clear and unambiguous. +3. Difficulty Distribution: Easy 30%, Medium 50%, Hard 20% +4. explanation field: 1-2 sentences briefly explaining why the answer is correct. Do NOT analyze each wrong option. 
+ +Return strictly in this JSON format (no extra fields): ```json [ {{ "id": "q1", "question": "Question text", "options": [ - {{"label": "A", "text": "Option A text"}}, - {{"label": "B", "text": "Option B text"}}, - {{"label": "C", "text": "Option C text"}}, - {{"label": "D", "text": "Option D text"}} + {{"label": "A", "text": "Option A"}}, + {{"label": "B", "text": "Option B"}}, + {{"label": "C", "text": "Option C"}}, + {{"label": "D", "text": "Option D"}} ], "correct_answer": "A", - "explanation": "Detailed explanation of why A is correct and why other options are incorrect.", - "source_excerpt": "Relevant excerpt from the document", - "difficulty": "medium", - "category": "application" + "explanation": "Brief explanation (1-2 sentences)" }} ] ``` -Ensure the response is valid JSON format.""" +Ensure the response is complete, valid JSON.""" return prompt +def _try_parse_json_array(json_str: str): + """尝试解析 JSON 数组,失败时逐步回退到最后一个完整对象""" + # 先直接尝试 + try: + return json.loads(json_str) + except json.JSONDecodeError: + pass + + # 找所有顶层 '}' 的位置(每个代表一个 question 对象的结尾) + # 从后往前逐个尝试截断 + 闭合 + brace_depth = 0 + bracket_depth = 0 + in_string = False + escape = False + candidates = [] + for i, ch in enumerate(json_str): + if escape: + escape = False + continue + if ch == '\\' and in_string: + escape = True + continue + if ch == '"' and not escape: + in_string = not in_string + continue + if in_string: + continue + if ch == '{': + brace_depth += 1 + elif ch == '}': + brace_depth -= 1 + if brace_depth == 0: + candidates.append(i) + elif ch == '[': + bracket_depth += 1 + elif ch == ']': + bracket_depth -= 1 + + # 从最后一个完整对象往前尝试 + for pos in reversed(candidates): + attempt = json_str[:pos + 1] + ']' + try: + return json.loads(attempt) + except json.JSONDecodeError: + continue + + raise json.JSONDecodeError("No valid JSON array found", json_str, 0) + + def _parse_quiz_from_llm_response(content: str, question_count: int) -> List[QuizQuestion]: """ 从 LLM 返回的内容中解析 Quiz 题目 """ try: # 
尝试提取 JSON(可能包含在 markdown 代码块中) - json_match = re.search(r'```(?:json)?\s*(\[[\s\S]*?\])\s*```', content) + # 用贪婪匹配,因为内容可能被截断没有闭合的 ``` + json_match = re.search(r'```(?:json)?\s*(\[[\s\S]*)', content) if json_match: json_str = json_match.group(1) + # 去掉尾部可能的 ``` + json_str = re.sub(r'\s*```\s*$', '', json_str) else: - # 尝试直接解析整个内容 json_str = content.strip() - # 尝试修复常见的 JSON 格式问题 - # 1. 移除可能的尾部不完整内容 - if not json_str.endswith(']'): - # 找到最后一个完整的对象 - last_complete = json_str.rfind('}') - if last_complete > 0: - json_str = json_str[:last_complete + 1] + ']' - - # 解析 JSON - questions_data = json.loads(json_str) + questions_data = _try_parse_json_array(json_str) # 转换为 QuizQuestion 对象 questions = [] for i, q_data in enumerate(questions_data[:question_count]): - # 确保有 4 个选项 options = [] for opt in q_data.get("options", [])[:4]: options.append(QuizOption( label=opt.get("label", ""), text=opt.get("text", "") )) - - # 如果选项不足 4 个,补充空选项 while len(options) < 4: - label = chr(65 + len(options)) # A, B, C, D + label = chr(65 + len(options)) options.append(QuizOption(label=label, text="")) question = QuizQuestion( @@ -231,12 +244,13 @@ def _parse_quiz_from_llm_response(content: str, question_count: int) -> List[Qui options=options, correct_answer=q_data.get("correct_answer", "A"), explanation=q_data.get("explanation", ""), - source_excerpt=q_data.get("source_excerpt"), - difficulty=q_data.get("difficulty", "medium"), - category=q_data.get("category", "application") ) questions.append(question) + if not questions: + raise Exception("解析后题目列表为空") + + log.info(f"[quiz_service] 成功解析 {len(questions)} 道题目(请求 {question_count} 道)") return questions except Exception as e: diff --git a/fastapi_app/services/search_and_add_service.py b/fastapi_app/services/search_and_add_service.py new file mode 100644 index 0000000..b0095a4 --- /dev/null +++ b/fastapi_app/services/search_and_add_service.py @@ -0,0 +1,235 @@ +""" +Search & Add 服务 +简单的 Web 搜索 + Top10 爬取功能 +""" +import asyncio +import httpx +from 
typing import List, Dict, Any +from bs4 import BeautifulSoup +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +class SearchAndAddService: + """Search & Add 服务""" + + def __init__(self): + self.timeout = 30.0 + + async def search_and_crawl( + self, + query: str, + top_k: int = 10, + search_provider: str = "serper", + search_api_key: str = None, + ) -> Dict[str, Any]: + """ + 搜索并爬取 Top K 结果 + + Args: + query: 搜索查询 + top_k: 返回前 K 个结果 + search_provider: 搜索引擎提供商 + search_api_key: 搜索 API 密钥 + + Returns: + { + "success": bool, + "query": str, + "sources": List[{ + "title": str, + "url": str, + "snippet": str, + "content": str, # 爬取的完整内容 + "crawl_success": bool + }] + } + """ + log.info(f"[SearchAndAdd] 开始搜索: {query}, top_k={top_k}") + + try: + # 1. 执行搜索 + search_results = await self._search( + query, top_k, search_provider, search_api_key + ) + + if not search_results: + return { + "success": False, + "query": query, + "sources": [], + "error": "搜索未返回结果" + } + + # 2. 并发爬取所有结果 + crawl_tasks = [ + self._crawl_url(result["url"], result["title"]) + for result in search_results[:top_k] + ] + crawled_contents = await asyncio.gather(*crawl_tasks, return_exceptions=True) + + # 3. 
合并搜索结果和爬取内容 + sources = [] + for i, result in enumerate(search_results[:top_k]): + crawl_result = crawled_contents[i] + + if isinstance(crawl_result, Exception): + log.warning(f"[SearchAndAdd] 爬取失败 {result['url']}: {crawl_result}") + sources.append({ + **result, + "content": result["snippet"], # 降级使用摘要 + "crawl_success": False + }) + else: + sources.append({ + **result, + "content": crawl_result, + "crawl_success": True + }) + + log.info(f"[SearchAndAdd] 完成,成功爬取 {sum(s['crawl_success'] for s in sources)}/{len(sources)} 个页面") + + return { + "success": True, + "query": query, + "sources": sources + } + + except Exception as e: + log.error(f"[SearchAndAdd] 执行失败: {e}") + return { + "success": False, + "query": query, + "sources": [], + "error": str(e) + } + + async def _search( + self, + query: str, + top_k: int, + provider: str, + api_key: str = None + ) -> List[Dict[str, str]]: + """执行搜索""" + if provider == "serper": + return await self._search_serper(query, top_k, api_key) + else: + raise ValueError(f"不支持的搜索提供商: {provider}") + + async def _search_serper( + self, + query: str, + top_k: int, + api_key: str = None + ) -> List[Dict[str, str]]: + """使用 Serper API 搜索""" + import os + api_key = api_key or os.getenv("SERPER_API_KEY") + + if not api_key: + raise ValueError("SERPER_API_KEY 未配置") + + url = "https://google.serper.dev/search" + headers = { + "X-API-KEY": api_key, + "Content-Type": "application/json" + } + payload = { + "q": query, + "num": top_k + } + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post(url, json=payload, headers=headers) + response.raise_for_status() + data = response.json() + + results = [] + for item in data.get("organic", [])[:top_k]: + results.append({ + "title": item.get("title", ""), + "url": item.get("link", ""), + "snippet": item.get("snippet", "") + }) + + return results + + async def _crawl_url(self, url: str, title: str) -> str: + """ + 爬取单个 URL 的内容 + + Args: + url: 目标 URL + title: 页面标题 + + 
Returns: + 爬取的文本内容(Markdown 格式) + """ + try: + async with httpx.AsyncClient( + timeout=self.timeout, + follow_redirects=True, + headers={ + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" + } + ) as client: + response = await client.get(url) + response.raise_for_status() + + # 解析 HTML + soup = BeautifulSoup(response.text, "html.parser") + + # 移除脚本和样式 + for script in soup(["script", "style", "nav", "footer", "header"]): + script.decompose() + + # 提取主要内容 + # 优先查找常见的内容容器 + main_content = None + for selector in ["article", "main", ".content", "#content", ".post", ".entry"]: + main_content = soup.select_one(selector) + if main_content: + break + + if not main_content: + main_content = soup.body + + if not main_content: + return f"# {title}\n\n无法提取页面内容" + + # 提取文本 + text = main_content.get_text(separator="\n", strip=True) + + # 清理多余空行 + lines = [line.strip() for line in text.split("\n") if line.strip()] + text = "\n\n".join(lines) + + # 限制长度(避免过长) + max_chars = 50000 + if len(text) > max_chars: + text = text[:max_chars] + "\n\n...(内容已截断)" + + # 格式化为 Markdown + markdown = f"# {title}\n\n**Source:** {url}\n\n---\n\n{text}" + + return markdown + + except Exception as e: + log.error(f"[SearchAndAdd] 爬取 {url} 失败: {e}") + raise + + def format_sources_as_markdown(self, sources: List[Dict[str, Any]]) -> str: + """将多个来源格式化为单个 Markdown 文档""" + md_parts = [] + + for i, source in enumerate(sources, 1): + md_parts.append(f"# Source {i}: {source['title']}") + md_parts.append(f"\n**URL:** {source['url']}") + md_parts.append(f"\n**Crawl Status:** {'✓ Success' if source['crawl_success'] else '✗ Failed'}") + md_parts.append("\n---\n") + md_parts.append(source['content']) + md_parts.append("\n\n" + "="*80 + "\n\n") + + return "".join(md_parts) diff --git a/fastapi_app/source_manager.py b/fastapi_app/source_manager.py index b1678e3..5778163 100644 --- a/fastapi_app/source_manager.py +++ b/fastapi_app/source_manager.py @@ -18,8 +18,8 @@ from pathlib import Path 
from typing import List, Optional, Tuple -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root from fastapi_app.notebook_paths import NotebookPaths @@ -298,7 +298,7 @@ def list_sources(self) -> List[SourceInfo]: async def _run_mineru(self, pdf_path: Path, output_dir: Path) -> None: """Run MinerU on a PDF file.""" - from dataflow_agent.toolkits.multimodaltool.mineru_tool import run_mineru_pdf_extract + from workflow_engine.toolkits.multimodaltool.mineru_tool import run_mineru_pdf_extract await asyncio.to_thread( run_mineru_pdf_extract, str(pdf_path), diff --git a/fastapi_app/utils.py b/fastapi_app/utils.py index 1b0a145..743b149 100644 --- a/fastapi_app/utils.py +++ b/fastapi_app/utils.py @@ -7,8 +7,8 @@ from fastapi import HTTPException, Request -from dataflow_agent.logger import get_logger -from dataflow_agent.utils import get_project_root +from workflow_engine.logger import get_logger +from workflow_engine.utils import get_project_root log = get_logger(__name__) diff --git a/fastapi_app/utils/error_handler.py b/fastapi_app/utils/error_handler.py new file mode 100644 index 0000000..cd74e42 --- /dev/null +++ b/fastapi_app/utils/error_handler.py @@ -0,0 +1,29 @@ +from fastapi import HTTPException +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +def handle_exception( + e: Exception, + context: str, + user_message: str = "操作失败", + status_code: int = 500, + log_level: str = "error" +) -> HTTPException: + """ + Handle exception with detailed logging and safe user message. 
+ + Args: + e: The exception to handle + context: Context description for logging + user_message: Safe message to return to user + status_code: HTTP status code + log_level: Logging level (error, warning, info, debug) + + Returns: + HTTPException with safe user message + """ + logger_func = getattr(log, log_level) + logger_func(f"{context} 失败: {type(e).__name__}: {str(e)}", exc_info=True) + return HTTPException(status_code=status_code, detail=user_message) diff --git a/fastapi_app/workflow_adapters/wa_data_insight.py b/fastapi_app/workflow_adapters/wa_data_insight.py new file mode 100644 index 0000000..74131d7 --- /dev/null +++ b/fastapi_app/workflow_adapters/wa_data_insight.py @@ -0,0 +1,108 @@ +""" +Data Insight Workflow Adapter +Mandatory isolation layer between Service and Workflow. +""" +from __future__ import annotations +from typing import Dict, Any +from workflow_engine.state import DataInsightState, DataInsightRequest +from workflow_engine.workflow.registry import RuntimeRegistry +from workflow_engine.logger import get_logger + +log = get_logger(__name__) + + +class DataInsightAdapter: + """ + Adapter for data insight workflow. + Converts API request dict to workflow state and executes workflow. + """ + + async def execute(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute data insight workflow. 
+ + Args: + request_data: Dict with keys: + - file_ids: List[str] + - model: str + - api_key: str + - chat_api_url: str + - output_mode: str + - analysis_goal: Optional[str] + - language: str + - email: Optional[str] + + Returns: + Dict with keys: + - status: "success" | "error" + - synthesized_insights: List[str] + - raw_insights: List[str] + - summary: str + - detailed_appendix: Dict (if detailed mode) + - result_path: str + - error: str (if error) + """ + try: + # Build workflow request + wf_request = DataInsightRequest( + file_ids=request_data.get("file_ids", []), + output_mode=request_data.get("output_mode", "concise"), + analysis_goal=request_data.get("analysis_goal"), + model=request_data.get("model", "deepseek-v3.2"), + api_key=request_data.get("api_key", ""), + chat_api_url=request_data.get("chat_api_url", ""), + language=request_data.get("language", "en") + ) + + # Add email if provided + if request_data.get("email"): + wf_request.email = request_data["email"] + + # Build workflow state + state = DataInsightState(request=wf_request) + + # Execute workflow + log.info("Executing data_insight workflow") + factory = RuntimeRegistry.get("data_insight") + builder = factory() + graph = builder.build() + + result_state = await graph.ainvoke(state) + + # Handle both dict and dataclass returns + if isinstance(result_state, dict): + # Result is a dict + synthesized_insights = result_state.get("synthesized_insights", []) + raw_insights = result_state.get("raw_insights", []) + summary = result_state.get("summary", "") + detailed_appendix = result_state.get("detailed_appendix", {}) + result_path = result_state.get("result_path", "") + else: + # Result is a DataInsightState object + synthesized_insights = result_state.synthesized_insights + raw_insights = result_state.raw_insights + summary = result_state.summary + detailed_appendix = result_state.detailed_appendix + result_path = result_state.result_path + + # Format response + return { + "status": "success", + 
"synthesized_insights": synthesized_insights, + "raw_insights": raw_insights, + "summary": summary, + "detailed_appendix": detailed_appendix, + "result_path": result_path + } + + except Exception as e: + log.error(f"Adapter execution failed: {e}", exc_info=True) + return { + "status": "error", + "error": str(e), + "synthesized_insights": [], + "raw_insights": [], + "summary": f"Analysis failed: {str(e)}", + "detailed_appendix": {}, + "result_path": "" + } diff --git a/fastapi_app/workflow_adapters/wa_paper2ppt.py b/fastapi_app/workflow_adapters/wa_paper2ppt.py index 4bbcf3d..1e7b79b 100644 --- a/fastapi_app/workflow_adapters/wa_paper2ppt.py +++ b/fastapi_app/workflow_adapters/wa_paper2ppt.py @@ -15,11 +15,11 @@ from pathlib import Path from typing import Any, List -from dataflow_agent.logger import get_logger -from dataflow_agent.state import Paper2FigureState -from dataflow_agent.toolkits.multimodaltool.mineru_tool import _shrink_markdown -from dataflow_agent.utils import get_project_root -from dataflow_agent.workflow import run_workflow +from workflow_engine.logger import get_logger +from workflow_engine.state import Paper2FigureState +from workflow_engine.toolkits.multimodaltool.mineru_tool import _shrink_markdown +from workflow_engine.utils import get_project_root +from workflow_engine.workflow import run_workflow from fastapi_app.notebook_paths import get_notebook_paths from fastapi_app.schemas import Paper2PPTRequest, Paper2PPTResponse diff --git a/fastapi_app/workflow_adapters_old.py b/fastapi_app/workflow_adapters_old.py index 3132405..090cc0f 100644 --- a/fastapi_app/workflow_adapters_old.py +++ b/fastapi_app/workflow_adapters_old.py @@ -16,15 +16,15 @@ # import json # import time -# from dataflow_agent.state import Paper2FigureState, DFRequest, DFState, Paper2FigureRequest as DF_Paper2FigureRequest -# from dataflow_agent.workflow import run_workflow -# from dataflow_agent.logger import get_logger -# from dataflow_agent.state import Paper2VideoRequest, 
Paper2VideoState -# from dataflow_agent.utils import get_project_root -# from dataflow_agent.workflow.wf_pipeline_recommend_extract_json import ( +# from workflow_engine.state import Paper2FigureState, DFRequest, DFState, Paper2FigureRequest as DF_Paper2FigureRequest +# from workflow_engine.workflow import run_workflow +# from workflow_engine.logger import get_logger +# from workflow_engine.state import Paper2VideoRequest, Paper2VideoState +# from workflow_engine.utils import get_project_root +# from workflow_engine.workflow.wf_pipeline_recommend_extract_json import ( # create_pipeline_graph, # ) -# from dataflow_agent.workflow.wf_pipeline_write import create_operator_write_graph +# from workflow_engine.workflow.wf_pipeline_write import create_operator_write_graph # from .schemas import ( # OperatorWriteRequest, @@ -330,7 +330,7 @@ # ) # state = Paper2VideoState(request=req, messages=[]) -# from dataflow_agent.workflow.wf_paper2video import create_paper2video_graph +# from workflow_engine.workflow.wf_paper2video import create_paper2video_graph # graph = create_paper2video_graph().build() # final_state: Paper2VideoState = await graph.ainvoke(state) diff --git a/frontend/.vite/deps/_metadata.json b/frontend/.vite/deps/_metadata.json deleted file mode 100644 index e6e0fab..0000000 --- a/frontend/.vite/deps/_metadata.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "hash": "9318ff93", - "configHash": "ba56d276", - "lockfileHash": "e3b0c442", - "browserHash": "3edd464a", - "optimized": {}, - "chunks": {} -} \ No newline at end of file diff --git a/frontend/.vite/deps/package.json b/frontend/.vite/deps/package.json deleted file mode 100644 index 3dbc1ca..0000000 --- a/frontend/.vite/deps/package.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "type": "module" -} diff --git a/frontend_en/Dockerfile b/frontend_en/Dockerfile index a1bee6e..022ef93 100644 --- a/frontend_en/Dockerfile +++ b/frontend_en/Dockerfile @@ -9,7 +9,7 @@ COPY frontend-v2/ ./ ARG 
VITE_API_KEY=df-internal-2024-workflow-key ARG VITE_DEFAULT_LLM_API_URL=https://api.apiyi.com/v1 -ARG VITE_LLM_API_URLS=https://api.apiyi.com/v1,http://b.apiyi.com:16888/v1,http://123.129.219.111:3000/v1 +ARG VITE_LLM_API_URLS=https://api.apiyi.com/v1,http://b.apiyi.com:16888/v1,http://172.96.160.199:3000/v1 ARG VITE_API_BASE_URL= ENV VITE_API_KEY=$VITE_API_KEY \ diff --git a/frontend_en/INTEGRATION_GUIDE.md b/frontend_en/INTEGRATION_GUIDE.md deleted file mode 100644 index 1ea2f33..0000000 --- a/frontend_en/INTEGRATION_GUIDE.md +++ /dev/null @@ -1,340 +0,0 @@ -# NotebookLM v2 知识库集成指南 - -## 📋 概述 - -本项目成功将 `frontend-workflow` 知识库的后端功能集成到 `frontend-v2` (NotebookLM) 项目中。 - -**核心目标**: -- ✅ 保持 Notebook 的前端界面风格 -- ✅ 接入知识库的后端 API -- ✅ 支持文件上传、智能问答、PPT生成、思维导图等功能 -- ✅ 不改动后端逻辑,前端适配后端 - ---- - -## 🔧 完成的工作 - -### 1. 项目配置 - -#### 1.1 创建的配置文件 - -- **`vite.config.ts`** - Vite 配置,包含后端代理 - - 端口: 3001 - - 后端代理: `http://localhost:8210` - -- **`src/config/api.ts`** - API 配置文件 - - API Key 管理 - - apiFetch 封装函数 - -- **`src/lib/supabase.ts`** - Supabase 客户端 - - 支持认证和数据库操作 - - 未配置时自动降级 - -- **`src/stores/authStore.ts`** - 认证状态管理 - - 使用 Zustand 管理用户状态 - - 支持 session 管理 - -- **`src/services/apiSettingsService.ts`** - API 设置服务 - - 管理 LLM API 配置 - -#### 1.2 依赖更新 - -在 `package.json` 中添加了: -```json -{ - "@supabase/supabase-js": "^2.89.0", - "mermaid": "^10.6.1", - "zustand": "^4.4.7" -} -``` - -### 2. 核心组件重构 - -#### 2.1 NotebookView.tsx (二级界面) - -**原有功能**:静态展示的笔记本界面 - -**新增功能**: -1. **文件管理** - - 从 Supabase 获取用户上传的文件 - - 支持文件选择(多选) - - 支持文件上传 - -2. **智能问答** - - 基于选中文件的 RAG 问答 - - 显示历史对话 - - 实时流式响应 - -3. **Studio 工具** - - PPT 生成 - - 思维导图生成(支持 Mermaid 渲染) - - 知识播客生成 - - 视频讲解生成 - - 语义检索 - -4. **UI 保持** - - 保留了原有的 NotebookLM 风格 - - 三栏布局(来源、对话、工具) - - Tab 切换(对话、检索、来源管理) - -#### 2.2 App.tsx - -添加了认证初始化逻辑: -- 如果 Supabase 已配置,使用真实认证 -- 如果未配置,创建模拟用户(方便开发) - -#### 2.3 Dashboard.tsx (一级界面) - -保持原有设计,展示: -- 精选笔记本(写死的数据) -- 最近打开的笔记本(写死的数据) -- 新建笔记本入口 - -### 3. 
工具组件 - -#### 3.1 MermaidPreview.tsx - -从 `frontend-workflow` 复制的思维导图预览组件: -- 支持 Mermaid 代码渲染 -- 支持放大预览 -- 支持下载 SVG 和源代码 -- 支持编辑和实时预览 - -### 4. 类型定义 - -更新 `src/types/index.ts`: -```typescript -export type MaterialType = 'image' | 'doc' | 'video' | 'link' | 'audio'; -export interface KnowledgeFile { ... } -export interface ChatMessage { ... } -export type SectionType = 'library' | 'upload' | 'output' | 'settings'; -export type ToolType = 'chat' | 'ppt' | 'mindmap' | 'podcast' | 'video' | 'search'; -``` - ---- - -## 🚀 使用指南 - -### 启动项目 - -```bash -cd /data/users/szl/opennotebook/Paper2Any/frontend-v2 - -# 安装依赖(已完成) -npm install - -# 启动开发服务器 -npm run dev -``` - -访问: `http://localhost:3001` - -### 环境变量配置 - -如果需要使用 Supabase,创建 `.env` 文件: - -```env -VITE_SUPABASE_URL=your_supabase_url -VITE_SUPABASE_ANON_KEY=your_anon_key -VITE_API_KEY=df-internal-2024-workflow-key -``` - -如果不配置,将使用模拟用户。 - -### 后端服务 - -确保后端服务运行在 `http://localhost:8210` - -主要 API 端点: -- `POST /api/v1/kb/upload` - 文件上传 -- `POST /api/v1/kb/chat` - 智能问答 -- `POST /api/v1/kb/mindmap` - 思维导图生成 -- `POST /api/v1/kb/ppt` - PPT 生成 -- `POST /api/v1/kb/podcast` - 播客生成 - ---- - -## 📁 项目结构 - -``` -frontend-v2/ -├── src/ -│ ├── components/ -│ │ └── knowledge-base/ -│ │ └── tools/ -│ │ └── MermaidPreview.tsx # 思维导图组件 -│ ├── config/ -│ │ └── api.ts # API 配置 -│ ├── lib/ -│ │ └── supabase.ts # Supabase 客户端 -│ ├── pages/ -│ │ ├── Dashboard.tsx # 一级界面 -│ │ └── NotebookView.tsx # 二级界面(核心) -│ ├── services/ -│ │ └── apiSettingsService.ts # API 设置 -│ ├── stores/ -│ │ └── authStore.ts # 认证状态 -│ ├── types/ -│ │ └── index.ts # 类型定义 -│ ├── App.tsx # 主应用 -│ └── main.tsx # 入口 -├── vite.config.ts # Vite 配置 -├── package.json # 依赖配置 -└── README.md # 项目说明 -``` - ---- - -## 🔄 适配逻辑 - -### 前端适配后端 - -1. **文件上传** - - 前端: 通过 `` 上传 - - 后端: `POST /api/v1/kb/upload` - - 数据库: 保存到 Supabase `knowledge_base_files` 表 - -2. **文件管理** - - 前端: 从 Supabase 读取用户的文件列表 - - 显示: 左侧来源面板 - - 选择: 支持多选,自动选中所有文件 - -3. 
**智能问答** - - 输入: 用户问题 + 选中的文件 - - 后端: `POST /api/v1/kb/chat` - - 参数: `{ files, query, history, api_url, api_key }` - - 响应: `{ answer, file_analyses }` - -4. **工具生成** - - 思维导图: `POST /api/v1/kb/mindmap` -> 返回 `mindmap_code` - - PPT: `POST /api/v1/kb/ppt` -> 返回 `ppt_url` - - 播客: `POST /api/v1/kb/podcast` -> 返回 `audio_url` - -### 数据流 - -``` -用户上传文件 - ↓ -前端调用 /api/v1/kb/upload - ↓ -后端保存文件,返回 URL - ↓ -前端保存到 Supabase - ↓ -用户选择文件,发起问答/生成 - ↓ -前端调用相应 API,传入选中文件的 URL - ↓ -后端处理,返回结果 - ↓ -前端展示(对话框/工具面板) -``` - ---- - -## ✅ 功能清单 - -### 已实现功能 - -- [x] 文件上传(支持 PDF, DOCX, PPTX, 图片等) -- [x] 文件列表展示 -- [x] 文件多选 -- [x] 智能问答(RAG) -- [x] 思维导图生成(Mermaid 渲染) -- [x] PPT 生成 -- [x] 播客生成 -- [x] 前端样式保持 NotebookLM 风格 -- [x] 认证状态管理(支持 Supabase 或模拟用户) -- [x] 后端 API 集成 - -### 待完善功能 - -- [ ] 来源管理(删除、重新索引等) -- [ ] 多模态检索功能实现 -- [ ] 视频讲解生成完整实现 -- [ ] 一级界面(Dashboard)接入真实数据 -- [ ] API 设置界面 -- [ ] 错误处理优化 -- [ ] 加载状态优化 - ---- - -## 🎯 核心原则 - -1. **前端适配后端** - - 不修改后端 API - - 前端调用现有接口 - - 数据格式遵循后端规范 - -2. **保持 UI 风格** - - Notebook 的界面设计 - - 三栏布局 - - 原有的交互逻辑 - -3. **复用知识库功能** - - 文件管理 - - 智能工具 - - 认证系统 - ---- - -## 📝 注意事项 - -1. **后端依赖** - - 必须运行 Paper2Any 后端服务 - - 默认端口: 8210 - - 确保后端 API 可访问 - -2. **数据库** - - 如果使用 Supabase,需要配置环境变量 - - 表结构: `knowledge_base_files` - - 字段: user_id, file_name, file_type, storage_path 等 - -3. **开发模式** - - 未配置 Supabase 时,使用模拟用户 - - 模拟用户 ID: `dev-user-001` - - 适合纯前端开发 - -4. **生产环境** - - 必须配置 Supabase - - 必须配置 LLM API - - 需要真实的认证系统 - ---- - -## 🐛 常见问题 - -### Q: 上传文件失败? -A: 检查后端服务是否运行,检查网络代理配置 - -### Q: 无法登录? -A: 如果未配置 Supabase,会自动使用模拟用户 - -### Q: 生成工具无响应? -A: 检查是否选中了文件,检查后端 API 状态 - -### Q: 思维导图不显示? -A: 检查 mermaid 依赖是否安装,检查返回的代码格式 - ---- - -## 📞 技术支持 - -如有问题,请检查: -1. 后端服务日志 -2. 浏览器控制台错误 -3. 网络请求状态 -4. 环境变量配置 - ---- - -## 🎉 总结 - -本次集成成功将知识库的强大后端功能与 NotebookLM 优雅的前端界面结合,实现了: - -- ✅ **无缝集成** - 前端完全适配后端 API -- ✅ **功能完整** - 支持所有核心知识库工具 -- ✅ **开发友好** - 支持模拟用户,方便调试 -- ✅ **可扩展** - 易于添加新功能 - -现在您可以使用 NotebookLM 的界面,享受知识库的全部功能! 
diff --git a/frontend_en/PREVIEW_FEATURE.md b/frontend_en/PREVIEW_FEATURE.md deleted file mode 100644 index 30599f3..0000000 --- a/frontend_en/PREVIEW_FEATURE.md +++ /dev/null @@ -1,139 +0,0 @@ -# 产出预览功能说明 - -## ✅ 已完成的功能 - -### 1. 产出内容信息流 - -在右侧 Studio 面板下方显示"产出内容"列表: - -- **PPT 生成** - 显示来源文件,提供预览和下载按钮 -- **思维导图** - 显示来源文件,提供预览和下载按钮 -- **播客生成** - 显示来源文件,提供预览和下载按钮 - -每条产出记录包含: -- 标题(工具类型) -- 来源文件列表 -- 生成时间 -- 预览和下载按钮 - ---- - -### 2. 点击预览功能 - -点击产出内容卡片后,会打开全屏预览模态框: - -#### PPT 预览 -- 使用 iframe 直接嵌入 PDF 预览 -- 支持下载 PDF 文件 -- 大屏展示,便于查看内容 - -#### 播客预览 -- 精美的音频播放器界面 -- 支持自动播放 -- 显示播客标题和生成时间 -- 支持下载音频文件 - -#### 思维导图预览 -- 使用 Mermaid 渲染引擎 -- 支持放大缩小 -- 支持查看代码/图形切换 -- 支持下载 SVG 和源代码 - ---- - -## 🎨 界面特点 - -### 产出列表卡片 -- 白色背景,圆角设计 -- hover 悬浮效果 -- 点击整个卡片即可预览 -- 独立的下载按钮 - -### 预览模态框 -- 全屏半透明背景 -- 居中展示,最大化利用空间 -- 支持点击外部关闭 -- 根据不同类型展示不同内容 - ---- - -## 📋 使用流程 - -1. **上传文件** → 左侧来源列表 -2. **选择文件** → 勾选需要的文件 -3. **点击工具卡片** → PPT/思维导图/播客 -4. **等待生成** → 右侧显示加载状态 -5. **查看产出** → 产出内容列表自动新增 -6. **点击预览** → 打开预览界面 -7. **下载文件** → 点击下载按钮 - ---- - -## 🔧 技术实现 - -### 状态管理 -```typescript -const [outputFeed, setOutputFeed] = useState>([]); - -const [previewOutput, setPreviewOutput] = useState<...>(null); -``` - -### 产出记录 -生成成功后自动追加到 `outputFeed`: -- 从后端返回值提取 URL -- 保存来源文件信息 -- 思维导图额外保存 mermaid 代码 - -### 预览渲染 -- **PPT**:`