diff --git a/.gitignore b/.gitignore index 9248013..59734d7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ AGENTS.md CLAUDE.md # Local files +data/ docs/ logs/ standalone_tagger.py diff --git a/README.md b/README.md index d68effd..5f960f9 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,7 @@ ![Tests](https://github.com/ComfyAssets/ComfyUI_PromptManager/workflows/Tests/badge.svg) ![Code Quality](https://github.com/ComfyAssets/ComfyUI_PromptManager/workflows/Code%20Quality/badge.svg) -A comprehensive ComfyUI custom node that extends the standard text encoder with persistent prompt storage, advanced search capabilities, automatic image gallery system, and powerful ComfyUI workflow metadata analysis using SQLite. - -## 📋 A Note on v2 Development - -V2 development will take longer as life has taken me in other directions for the moment. I will still work on it but not as actively as before. In the meantime, I will backport some of the features users have requested from v2 over to v1. Once v2 is ready, I will update here. Thank you for your continued support! - ---- +A comprehensive ComfyUI custom node that extends the standard text encoder with persistent prompt storage, advanced search capabilities, automatic image gallery system, folder-based organization, LoRA Manager integration, and powerful ComfyUI workflow metadata analysis using SQLite. ## Overview @@ -54,6 +48,7 @@ Both nodes include the complete PromptManager feature set: - **💾 Persistent Storage**: Automatically saves all prompts to a local SQLite database - **🔍 Advanced Search**: Query past prompts with text search, category filtering, and metadata +- **📁 Folder Filter**: Browse and filter prompts by output subdirectory - **🖼️ Automatic Image Gallery**: Automatically links generated images to their prompts - **🏷️ Rich Metadata**: Add categories, tags, ratings, notes, and workflow names to prompts - **🚫 Duplicate Prevention**: Uses SHA256 hashing to detect and prevent duplicate storage @@ -62,7 +57,8 @@ Both nodes include the complete PromptManager feature set: - **🔬 Workflow Analysis**: Extract and analyze ComfyUI workflow data from PNG images - **📋 Metadata Viewer**: Standalone tool for analyzing ComfyUI-generated images - **🛠️ System Management**: Built-in diagnostics, backup/restore, and maintenance tools -- **🏷️ AI AutoTag**: Automatically tag your image collection using JoyCaption vision models +- **🏷️ AI AutoTag**: Automatically tag your image collection using WD14 or JoyCaption vision models +- **🔗 LoRA Manager Integration**: Import LoRA metadata and preview images from [ComfyUI-Lora-Manager](https://github.com/willchil/ComfyUI-Lora-Manager) ![Image Gallery](images/pm-02.png) @@ -372,10 +368,11 @@ Import existing ComfyUI images into your database: ### 🏷️ AI AutoTag -Automatically tag your entire image collection using JoyCaption vision models: +Automatically tag your entire image collection using AI vision models: #### **Model Options** +- **WD14 SwinV2 / WD14 ViT**: Fast ONNX-based classifiers (~400MB) that output Danbooru tags with confidence scores — no prompt needed, adjustable confidence thresholds - **JoyCaption Beta One FP16**: Full precision model for highest quality tagging (requires more VRAM) - **JoyCaption Beta One GGUF (FP8)**: Quantized model for lower VRAM usage with minimal quality loss @@ -417,6 +414,61 @@ Adjust the system prompt to match your tagging style: - Models are downloaded automatically on first use - Feature requests and improvements are welcome! +### 📁 Folder Filter + +![Folder Filter](images/pm-lora-import.png) + +Organize and filter your prompt library by output subdirectory. The dashboard search panel includes a **Folder** dropdown that lists all subdirectories where your images are stored. + +- **Automatic Detection**: Folders are extracted from your image paths — no configuration needed +- **Multi-Directory Support**: Works across all configured gallery scan directories +- **Quick Filtering**: Select a folder from the dropdown to instantly filter prompts to that subdirectory + +> **Note:** If you have an existing prompt library from before the folder feature was added, you will need to **rescan your image library** (click **Scan Images** in the admin dashboard) to populate folder data for your existing prompts. + +### 🔗 LoRA Manager Integration + +![LoRA Settings](images/pm-settings-integrations.png) + +If you use [ComfyUI-Lora-Manager](https://github.com/willchil/ComfyUI-Lora-Manager), PromptManager can import your LoRA metadata, trigger words, and example images directly into your prompt database. This lets you search, tag, and browse your LoRA collection alongside your regular prompts. + +> **WIP:** LoRA Manager support is a work in progress — this was a highly requested feature. Please [open issues](https://github.com/ComfyAssets/ComfyUI_PromptManager/issues) for any bugs or feature requests. + +#### **Enabling the Integration** + +1. Open the admin dashboard and click **Settings** +2. Scroll to the **Integrations** section — PromptManager will auto-detect if LoRA Manager is installed + +![LoRA Enabled](images/pm-lora-enabled.png) + +3. Toggle the **LoraManager** switch to enable +4. Optionally enable **Auto-inject trigger words** to automatically append LoRA trigger words when `` is detected in your prompts +5. Click **Save Settings** (requires a ComfyUI restart to take effect) + +#### **Importing LoRA Data** + +Click **Import LoRA Data** in the settings panel to start the import. A progress popup shows real-time status as each LoRA is processed: + +- Scans all LoRA directories (including `extra_model_paths.yaml` paths) for metadata +- Downloads example images from CivitAI as 512px thumbnails +- Creates a prompt entry for each LoRA with trigger words, example prompts, and preview images +- All imported LoRAs are tagged with `lora-manager` category for easy filtering + +![LoRA Results](images/pm-lora-filtered.png) + +After import, filter by the `lora-manager` category to browse your LoRA collection. + +#### **CivitAI API Key (Optional)** + +A CivitAI API key is **required if your library contains NSFW LoRAs** — CivitAI blocks unauthenticated access to NSFW preview images. Without a key, only SFW preview images are downloaded. + +To add your key: +1. Get your API key from [civitai.com/user/account](https://civitai.com/user/account) +2. Paste it into the **CivitAI API Key** field in the Integrations settings +3. Save and re-import to download any previously skipped images + +> **Note:** Re-importing is safe — PromptManager skips LoRAs that were already imported. To do a fresh import, the previous `lora-manager` data is cleared automatically before re-scanning. + ### 🌐 Web Interface Features The comprehensive web interface provides: @@ -600,8 +652,9 @@ CREATE TABLE generated_images ( - **`prompt_search_list.py`** - Batch search node implementation (LIST output) - **`database/models.py`** - Database schema and connection management - **`database/operations.py`** - CRUD operations and search functionality -- **`py/api.py`** - Web API endpoints for the interface +- **`py/api/`** - Web API route modules (prompts, images, tags, lora integration, etc.) - **`py/config.py`** - Configuration management +- **`py/lora_utils.py`** - LoRA Manager integration utilities - **`utils/hashing.py`** - SHA256 hashing for deduplication - **`utils/validators.py`** - Input validation and sanitization - **`utils/image_monitor.py`** - Automatic image detection system @@ -611,7 +664,8 @@ CREATE TABLE generated_images ( - **`utils/diagnostics.py`** - System diagnostics and health checks - **`web/admin.html`** - Advanced admin dashboard with metadata panel - **`web/index.html`** - Simple web interface -- **`web/prompt_manager.js`** - JavaScript functionality +- **`web/js/prompt_manager.js`** - Dashboard JavaScript +- **`web/js/tags-page.js`** - Tag management JavaScript - **`web/metadata.html`** - Standalone PNG metadata viewer ### File Structure @@ -628,8 +682,9 @@ ComfyUI_PromptManager/ │ └── operations.py # Database operations ├── py/ │ ├── __init__.py -│ ├── api.py # Web API endpoints -│ └── config.py # Configuration +│ ├── api/ # Web API route modules +│ ├── config.py # Configuration +│ └── lora_utils.py # LoRA Manager integration ├── utils/ │ ├── __init__.py │ ├── hashing.py # Hashing utilities @@ -641,9 +696,12 @@ ComfyUI_PromptManager/ │ └── diagnostics.py # System diagnostics ├── web/ │ ├── admin.html # Advanced admin dashboard +│ ├── gallery.html # Image gallery │ ├── index.html # Simple web interface │ ├── metadata.html # Standalone metadata viewer -│ └── prompt_manager.js # JavaScript functionality +│ └── js/ +│ ├── prompt_manager.js # Dashboard JavaScript +│ └── tags-page.js # Tag management JavaScript ├── tests/ │ ├── __init__.py │ └── test_basic.py # Test suite @@ -752,7 +810,7 @@ db.model.backup_database("backup_prompts.db") ### Running Tests ```bash -cd KikoTextEncode +cd ComfyUI_PromptManager python -m pytest tests/ -v ``` @@ -797,7 +855,7 @@ The project follows PEP 8 guidelines with: For debugging, you can enable verbose logging in the node: ```python -# Add to kiko_text_encode.py +# Add to prompt_manager.py import logging logging.basicConfig(level=logging.DEBUG) ``` @@ -814,15 +872,15 @@ MIT License - see LICENSE file for details. ## Roadmap -### Completed in v3.0.0 +### Recently Completed -- **✅ PNG Metadata Analysis**: Complete ComfyUI workflow extraction from images -- **✅ Standalone Metadata Viewer**: Dedicated tool for analyzing any ComfyUI image -- **✅ Advanced Admin Dashboard**: Comprehensive management interface with modern UI -- **✅ Integrated Metadata Panel**: Real-time workflow analysis in image viewer -- **✅ Bulk Image Scanning**: Mass import of existing ComfyUI images -- **✅ System Management Tools**: Backup, restore, diagnostics, and maintenance -- **✅ Enhanced Error Handling**: Robust PNG parsing with NaN value cleaning +- **✅ LoRA Manager Integration**: Import LoRA metadata, trigger words, and preview images +- **✅ Folder Filter**: Browse and filter prompts by output subdirectory +- **✅ Multi-Directory Gallery**: Scan multiple output directories simultaneously +- **✅ WD14 Tagger**: Fast ONNX-based auto-tagging with Danbooru tags +- **✅ Tailwind v4 Migration**: ComfyUI theme token system for consistent styling +- **✅ Tag Management Page**: Dedicated page for tag search, rename, merge, and delete +- **✅ Junction Table Tags**: Normalized tag storage with proper foreign keys ### Planned Features @@ -830,18 +888,27 @@ MIT License - see LICENSE file for details. - **🤝 Collaboration**: Share prompt collections with other users - **🧠 AI Suggestions**: Recommend similar prompts based on metadata analysis - **📈 Advanced Analytics**: Detailed usage statistics and trends with workflow insights -- **🔌 Plugin System**: Support for third-party extensions and custom analyzers -- **🎨 Enhanced Batch Processing**: Advanced bulk operations with metadata editing - **🔄 Workflow Templates**: Save and reuse common workflow patterns - **📊 Visual Analytics**: Charts and graphs for prompt effectiveness analysis -### Integration Ideas +## Changelog -- **Workflow linking**: Connect prompts to specific workflow templates -- **Image analysis**: Analyze generated images to improve suggestions -- **Version control**: Track prompt iterations and effectiveness +### v3.2.1 (LoRA Manager Integration) -## Changelog +- **🔗 LoRA Manager Integration**: Import LoRA metadata, trigger words, and CivitAI example images from [ComfyUI-Lora-Manager](https://github.com/willchil/ComfyUI-Lora-Manager) into your prompt database +- **💉 Auto-Inject Trigger Words**: Optionally append LoRA trigger words when `` is detected in prompts during encoding +- **🔑 CivitAI API Key Support**: Authenticate with CivitAI to download NSFW preview images +- **📥 Import Progress Modal**: Real-time SSE streaming progress during LoRA import with per-model status + +> **Note:** LoRA Manager support is a WIP — this was a highly requested feature. Please [open issues](https://github.com/ComfyAssets/ComfyUI_PromptManager/issues) for any bugs or feature requests. + +### v3.2.0 (Folder Filter & QoL Improvements) + +- **📁 Folder Filter**: New folder dropdown in the search panel to filter prompts by output subdirectory +- **📂 Multi-Directory Gallery Scan**: Configure multiple output directories and browse them all in one gallery +- **🖼️ Filmstrip Prompt Display**: Show prompt text and copy button in the filmstrip image viewer +- **🖱️ Click-to-Close Viewer**: Click outside the image viewer to close it +- **🐛 Bug Fixes**: Fixed database path config, UnboundLocalError in text inputs, node caching skip, diagnostics singleton lifecycle ### v3.1.0 (WD14 Tagger, Tailwind v4 & Major Refactors) diff --git a/database/operations.py b/database/operations.py index ad9b6ae..ce48a29 100644 --- a/database/operations.py +++ b/database/operations.py @@ -444,6 +444,38 @@ def delete_prompt(self, prompt_id: int) -> bool: conn.commit() return cursor.rowcount > 0 + def delete_prompts_by_category(self, category: str) -> int: + """Delete all prompts with the given category. + + Returns: + Number of prompts deleted. + """ + with self.model.get_connection() as conn: + # Get IDs first for cascade cleanup + ids = [ + r[0] + for r in conn.execute( + "SELECT id FROM prompts WHERE category = ?", (category,) + ).fetchall() + ] + if not ids: + return 0 + placeholders = ",".join("?" * len(ids)) + conn.execute( + f"DELETE FROM generated_images WHERE prompt_id IN ({placeholders})", + ids, + ) + conn.execute( + f"DELETE FROM prompt_tags WHERE prompt_id IN ({placeholders})", + ids, + ) + cursor = conn.execute( + f"DELETE FROM prompts WHERE id IN ({placeholders})", + ids, + ) + conn.commit() + return cursor.rowcount + def get_all_categories(self) -> List[str]: """ Get all unique categories from the database. diff --git a/images/pm-lora-enabled.png b/images/pm-lora-enabled.png new file mode 100644 index 0000000..a4a2f65 Binary files /dev/null and b/images/pm-lora-enabled.png differ diff --git a/images/pm-lora-filtered.png b/images/pm-lora-filtered.png new file mode 100644 index 0000000..adc1391 Binary files /dev/null and b/images/pm-lora-filtered.png differ diff --git a/images/pm-lora-import.png b/images/pm-lora-import.png new file mode 100644 index 0000000..0e118fd Binary files /dev/null and b/images/pm-lora-import.png differ diff --git a/images/pm-lora-results.png b/images/pm-lora-results.png new file mode 100644 index 0000000..a4e3b08 Binary files /dev/null and b/images/pm-lora-results.png differ diff --git a/images/pm-settings-integrations.png b/images/pm-settings-integrations.png new file mode 100644 index 0000000..3168180 Binary files /dev/null and b/images/pm-settings-integrations.png differ diff --git a/prompt_manager.py b/prompt_manager.py index e8d904a..a1eceff 100644 --- a/prompt_manager.py +++ b/prompt_manager.py @@ -148,6 +148,9 @@ def encode_prompt( parts.append(append_text.strip()) final_text = " ".join(parts) + # Inject LoRA trigger words if integration is enabled + final_text = self._inject_lora_trigger_words(final_text) + # Use the combined text for encoding encoding_text = final_text diff --git a/prompt_manager_base.py b/prompt_manager_base.py index 9a393fc..43a6326 100644 --- a/prompt_manager_base.py +++ b/prompt_manager_base.py @@ -102,6 +102,45 @@ def _save_prompt_to_database( self.logger.error(f"Error saving prompt to database: {e}") return None + def _inject_lora_trigger_words(self, text: str) -> str: + """Append LoRA trigger words if integration is enabled. + + Returns the text unchanged if the integration is disabled or + LoraManager is not installed. + """ + try: + from .py.config import IntegrationConfig + except ImportError: + try: + from py.config import IntegrationConfig + except ImportError: + return text + + if ( + not IntegrationConfig.LORA_MANAGER_ENABLED + or not IntegrationConfig.LORA_TRIGGER_WORDS_ENABLED + ): + return text + + try: + from .py.lora_utils import get_trigger_cache, inject_trigger_words + except ImportError: + try: + from py.lora_utils import get_trigger_cache, inject_trigger_words + except ImportError: + return text + + cache = get_trigger_cache() + + # Lazy-load cache on first use + if not cache.is_loaded and IntegrationConfig.LORA_MANAGER_PATH: + cache.load(IntegrationConfig.LORA_MANAGER_PATH) + + modified, injected = inject_trigger_words(text, cache) + if injected: + self.logger.info(f"Injected trigger words: {', '.join(injected)}") + return modified + def _generate_hash(self, text: str) -> str: """Generate SHA256 hash for the prompt text. diff --git a/prompt_manager_text.py b/prompt_manager_text.py index e10d146..6e0736c 100644 --- a/prompt_manager_text.py +++ b/prompt_manager_text.py @@ -142,6 +142,9 @@ def process_text( parts.append(append_text.strip()) final_text = " ".join(parts) + # Inject LoRA trigger words if integration is enabled + final_text = self._inject_lora_trigger_words(final_text) + # For database storage, save the original main text with metadata about prepend/append storage_text = text diff --git a/py/api/__init__.py b/py/api/__init__.py index 8d7af45..90d4695 100644 --- a/py/api/__init__.py +++ b/py/api/__init__.py @@ -24,6 +24,7 @@ from .admin import AdminRoutesMixin from .logging_routes import LoggingRoutesMixin from .autotag_routes import AutotagRoutesMixin +from .lora_integration import LoraIntegrationMixin try: from ...database.operations import PromptDatabase @@ -99,6 +100,7 @@ class PromptManagerAPI( AdminRoutesMixin, LoggingRoutesMixin, AutotagRoutesMixin, + LoraIntegrationMixin, ): """REST API handler for PromptManager operations and web interface. @@ -372,6 +374,7 @@ async def serve_js_static(request): self._register_admin_routes(routes) self._register_logging_routes(routes) self._register_autotag_routes(routes) + self._register_lora_routes(routes) # Register gzip compression middleware (once) global _gzip_registered diff --git a/py/api/images.py b/py/api/images.py index 0acdb4e..e6b06e1 100644 --- a/py/api/images.py +++ b/py/api/images.py @@ -303,11 +303,31 @@ async def serve_image(self, request): image_path = Path(image["image_path"]).resolve() - # Validate path is within any configured output directory - output_dirs = self._get_all_output_dirs() - if output_dirs: + # Validate path is within any allowed directory + allowed_dirs = list(self._get_all_output_dirs()) + + # Also allow LoRA directories when integration is enabled + try: + from ..config import IntegrationConfig + + if IntegrationConfig.LORA_MANAGER_ENABLED: + from ..lora_utils import ( + find_lora_directories, + get_lora_image_cache_dir, + ) + + lora_dirs = find_lora_directories( + IntegrationConfig.LORA_MANAGER_PATH + ) + allowed_dirs.extend(Path(d) for d in lora_dirs) + allowed_dirs.append(get_lora_image_cache_dir()) + except Exception: + # LoRA integration is optional — skip if unavailable + pass + + if allowed_dirs: allowed = any( - image_path.is_relative_to(d.resolve()) for d in output_dirs + image_path.is_relative_to(d.resolve()) for d in allowed_dirs ) if not allowed: return web.json_response( diff --git a/py/api/lora_integration.py b/py/api/lora_integration.py new file mode 100644 index 0000000..6275bfd --- /dev/null +++ b/py/api/lora_integration.py @@ -0,0 +1,419 @@ +"""LoraManager integration API routes for PromptManager.""" + +import json +import os +from pathlib import Path + +from aiohttp import web + + +class LoraIntegrationMixin: + """Mixin providing LoraManager detection, scanning, and trigger word endpoints.""" + + def _register_lora_routes(self, routes): + @routes.get("/prompt_manager/lora/detect") + async def lora_detect_route(request): + return await self.lora_detect(request) + + @routes.get("/prompt_manager/lora/status") + async def lora_status_route(request): + return await self.lora_status(request) + + @routes.post("/prompt_manager/lora/enable") + async def lora_enable_route(request): + return await self.lora_enable(request) + + @routes.post("/prompt_manager/lora/scan") + async def lora_scan_route(request): + return await self.lora_scan(request) + + @routes.get("/prompt_manager/lora/trigger-words") + async def lora_trigger_words_route(request): + return await self.lora_trigger_words(request) + + @routes.post("/prompt_manager/lora/refresh-cache") + async def lora_refresh_cache_route(request): + return await self.lora_refresh_cache(request) + + # ── Detection ──────────────────────────────────────────────────── + + async def lora_detect(self, request): + """Auto-detect LoraManager installation.""" + try: + from ..lora_utils import detect_lora_manager + + path = await self._run_in_executor(detect_lora_manager) + return web.json_response( + { + "success": True, + "detected": path is not None, + "path": path or "", + } + ) + except Exception as e: + self.logger.error(f"LoraManager detection failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) + + # ── Status ─────────────────────────────────────────────────────── + + async def lora_status(self, request): + """Get current LoraManager integration status.""" + try: + from ..config import IntegrationConfig + from ..lora_utils import detect_lora_manager, get_trigger_cache + + config = IntegrationConfig.get_config()["lora_manager"] + cache = get_trigger_cache() + + # Check if the configured path is still valid + detected_path = await self._run_in_executor( + detect_lora_manager, config.get("path", "") + ) + + return web.json_response( + { + "success": True, + "enabled": config["enabled"], + "path": config["path"], + "trigger_words_enabled": config["trigger_words_enabled"], + "civitai_api_key": config.get("civitai_api_key", ""), + "detected": detected_path is not None, + "detected_path": detected_path or "", + "trigger_cache_loaded": cache.is_loaded, + } + ) + except Exception as e: + self.logger.error(f"LoraManager status check failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) + + # ── Enable / Disable ───────────────────────────────────────────── + + async def lora_enable(self, request): + """Enable or disable LoraManager integration and save to config.json.""" + try: + data = await request.json() + enabled = data.get("enabled", False) + path = data.get("path", "") + trigger_words = data.get("trigger_words_enabled", False) + civitai_key = data.get("civitai_api_key", "") + + from ..config import IntegrationConfig, PromptManagerConfig + from ..lora_utils import detect_lora_manager, get_trigger_cache + + # If enabling, validate the path + if enabled: + resolved = await self._run_in_executor(detect_lora_manager, path) + if not resolved: + return web.json_response( + { + "success": False, + "error": "LoraManager not found at the specified path", + }, + status=400, + ) + path = resolved + + # Update in-memory config + IntegrationConfig.LORA_MANAGER_ENABLED = enabled + IntegrationConfig.LORA_MANAGER_PATH = path + IntegrationConfig.LORA_TRIGGER_WORDS_ENABLED = trigger_words + IntegrationConfig.CIVITAI_API_KEY = civitai_key + + # Persist to config.json + config_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + config_file = os.path.join(config_dir, "config.json") + PromptManagerConfig.save_to_file(config_file) + + # Load trigger word cache if enabling + cache = get_trigger_cache() + if enabled and trigger_words and path: + count = await self._run_in_executor(cache.load, path) + self.logger.info(f"Trigger word cache loaded: {count} LoRAs") + elif not enabled: + cache.clear() + + return web.json_response( + { + "success": True, + "enabled": enabled, + "path": path, + "trigger_words_enabled": trigger_words, + } + ) + except Exception as e: + self.logger.error(f"LoraManager enable/disable failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) + + # ── Scan LoRA example images ───────────────────────────────────── + + async def lora_scan(self, request): + """Scan LoraManager metadata and import LoRA info + preview images. + + Streams progress as SSE, matching the existing scan pattern. + """ + try: + from ..config import IntegrationConfig + from ..lora_utils import ( + download_civitai_images, + find_lora_directories, + get_example_prompt_from_metadata, + get_lora_image_cache_dir, + get_preview_images_from_metadata, + get_trigger_words_from_metadata, + get_model_name_from_metadata, + read_lora_metadata, + ) + + if not IntegrationConfig.LORA_MANAGER_ENABLED: + return web.json_response( + { + "success": False, + "error": "LoraManager integration is not enabled", + }, + status=400, + ) + + lm_path = IntegrationConfig.LORA_MANAGER_PATH + if not lm_path: + return web.json_response( + {"success": False, "error": "LoraManager path not configured"}, + status=400, + ) + + response = web.StreamResponse( + status=200, + reason="OK", + headers={"Content-Type": "text/event-stream"}, + ) + await response.prepare(request) + + async def send_progress(data): + line = f"data: {json.dumps(data)}\n\n" + await response.write(line.encode("utf-8")) + + await send_progress( + { + "type": "progress", + "status": "Clearing previous lora-manager imports...", + "progress": 0, + } + ) + + # Clear previous imports so reimport is always clean + await self._run_in_executor( + self.db.delete_prompts_by_category, "lora-manager" + ) + + await send_progress( + { + "type": "progress", + "status": "Finding LoRA directories...", + "progress": 2, + } + ) + + lora_dirs = await self._run_in_executor(find_lora_directories, lm_path) + + # Collect all metadata files + meta_files = [] + for d in lora_dirs: + dir_path = Path(d) + meta_files.extend(dir_path.rglob("*.metadata.json")) + + total = len(meta_files) + imported = 0 + skipped = 0 + + await send_progress( + { + "type": "progress", + "status": f"Found {total} LoRA metadata files", + "progress": 5, + "total": total, + } + ) + + cache_dir = get_lora_image_cache_dir() + + for i, meta_file in enumerate(meta_files): + metadata = await self._run_in_executor(read_lora_metadata, meta_file) + if not metadata: + skipped += 1 + continue + + model_name = get_model_name_from_metadata(metadata) + trigger_words = get_trigger_words_from_metadata(metadata) + + # Collect all images: local previews + downloaded civitai examples + preview_paths = await self._run_in_executor( + get_preview_images_from_metadata, metadata, meta_file + ) + civitai_paths = await self._run_in_executor( + download_civitai_images, + metadata, + meta_file, + cache_dir, + IntegrationConfig.CIVITAI_API_KEY, + ) + + # Merge, local first, dedup + seen = set(preview_paths) + all_images = list(preview_paths) + for cp in civitai_paths: + if cp not in seen: + all_images.append(cp) + seen.add(cp) + + # Build prompt text: prefer example prompt, then model name + example_prompt = get_example_prompt_from_metadata(metadata) + prompt_text = example_prompt or model_name + + # Build tags + tags = ["lora-manager", f"lora:{model_name}"] + tags.extend(trigger_words) + + # Save to database via existing mechanism + try: + import hashlib + + prompt_hash = hashlib.sha256( + prompt_text.strip().lower().encode("utf-8") + ).hexdigest() + + existing = await self._run_in_executor( + self.db.get_prompt_by_hash, prompt_hash + ) + + if existing: + # Link all images + for pp in all_images: + await self._run_in_executor( + self.db.link_image_to_prompt, + existing["id"], + pp, + ) + skipped += 1 + else: + prompt_id = await self._run_in_executor( + self.db.save_prompt, + prompt_text, + "lora-manager", # category + tags, + None, # rating + None, # notes + prompt_hash, + ) + + if prompt_id: + for pp in all_images: + await self._run_in_executor( + self.db.link_image_to_prompt, + prompt_id, + pp, + ) + imported += 1 + else: + skipped += 1 + + except Exception as e: + self.logger.warning(f"Failed to import LoRA {model_name}: {e}") + skipped += 1 + + # Progress update for every LoRA + progress = int(5 + (90 * (i + 1) / max(total, 1))) + img_count = len(all_images) + status = f"{model_name}" + if img_count: + status += f" ({img_count} images)" + await send_progress( + { + "type": "progress", + "status": status, + "progress": progress, + "processed": i + 1, + "imported": imported, + "skipped": skipped, + } + ) + + await send_progress( + { + "type": "complete", + "progress": 100, + "total": total, + "imported": imported, + "skipped": skipped, + } + ) + + await response.write_eof() + return response + + except Exception as e: + self.logger.error(f"LoRA scan failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) + + # ── Trigger word endpoints ─────────────────────────────────────── + + async def lora_trigger_words(self, request): + """Look up trigger words for a specific LoRA name.""" + try: + from ..config import IntegrationConfig + from ..lora_utils import get_trigger_cache + + if not IntegrationConfig.LORA_MANAGER_ENABLED: + return web.json_response( + {"success": False, "error": "LoraManager integration not enabled"}, + status=400, + ) + + lora_name = request.query.get("name", "") + if not lora_name: + return web.json_response( + {"success": False, "error": "Missing 'name' query parameter"}, + status=400, + ) + + cache = get_trigger_cache() + if not cache.is_loaded: + lm_path = IntegrationConfig.LORA_MANAGER_PATH + if lm_path: + await self._run_in_executor(cache.load, lm_path) + + words = cache.get_trigger_words(lora_name) + return web.json_response( + {"success": True, "lora": lora_name, "trigger_words": words} + ) + except Exception as e: + self.logger.error(f"Trigger word lookup failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) + + async def lora_refresh_cache(self, request): + """Force-refresh the trigger word cache from disk.""" + try: + from ..config import IntegrationConfig + from ..lora_utils import get_trigger_cache + + if not IntegrationConfig.LORA_MANAGER_ENABLED: + return web.json_response( + {"success": False, "error": "LoraManager integration not enabled"}, + status=400, + ) + + lm_path = IntegrationConfig.LORA_MANAGER_PATH + if not lm_path: + return web.json_response( + {"success": False, "error": "LoraManager path not configured"}, + status=400, + ) + + cache = get_trigger_cache() + count = await self._run_in_executor(cache.load, lm_path) + return web.json_response( + {"success": True, "loras_with_trigger_words": count} + ) + except Exception as e: + self.logger.error(f"Trigger cache refresh failed: {e}") + return web.json_response({"success": False, "error": str(e)}, status=500) diff --git a/py/config.py b/py/config.py index bea4b0d..88ef6b9 100644 --- a/py/config.py +++ b/py/config.py @@ -201,6 +201,43 @@ def update_config(cls, new_config: Dict[str, Any]): cls.METADATA_EXTRACTION_TIMEOUT = performance["metadata_extraction_timeout"] +class IntegrationConfig: + """Configuration for third-party extension integrations. + + Manages opt-in integration settings for extensions like LoraManager. + All integrations are disabled by default so PromptManager works standalone. + """ + + # LoraManager integration + LORA_MANAGER_ENABLED = False + LORA_MANAGER_PATH = "" # Auto-detected if empty + LORA_TRIGGER_WORDS_ENABLED = False # Auto-inject trigger words into prompts + CIVITAI_API_KEY = "" # Required to download NSFW example images + + @classmethod + def get_config(cls) -> Dict[str, Any]: + return { + "lora_manager": { + "enabled": cls.LORA_MANAGER_ENABLED, + "path": cls.LORA_MANAGER_PATH, + "trigger_words_enabled": cls.LORA_TRIGGER_WORDS_ENABLED, + "civitai_api_key": cls.CIVITAI_API_KEY, + }, + } + + @classmethod + def update_config(cls, new_config: Dict[str, Any]): + lora = new_config.get("lora_manager", {}) + if "enabled" in lora: + cls.LORA_MANAGER_ENABLED = lora["enabled"] + if "path" in lora: + cls.LORA_MANAGER_PATH = lora["path"] + if "trigger_words_enabled" in lora: + cls.LORA_TRIGGER_WORDS_ENABLED = lora["trigger_words_enabled"] + if "civitai_api_key" in lora: + cls.CIVITAI_API_KEY = lora["civitai_api_key"] + + class PromptManagerConfig: """Main configuration class for PromptManager core functionality. @@ -274,6 +311,7 @@ def get_config(cls) -> Dict[str, Any]: "auto_backup_interval": cls.AUTO_BACKUP_INTERVAL, }, "gallery": GalleryConfig.get_config(), + "integrations": IntegrationConfig.get_config(), } @classmethod @@ -389,6 +427,10 @@ def update_config(cls, new_config: Dict[str, Any]): if "gallery" in new_config: GalleryConfig.update_config(new_config["gallery"]) + # Update integration config + if "integrations" in new_config: + IntegrationConfig.update_config(new_config["integrations"]) + # Load configuration on import try: diff --git a/py/lora_utils.py b/py/lora_utils.py new file mode 100644 index 0000000..62c925f --- /dev/null +++ b/py/lora_utils.py @@ -0,0 +1,529 @@ +"""Utilities for LoraManager integration. + +Provides detection, metadata reading, and trigger word lookup for +ComfyUI-Lora-Manager (https://github.com/willmiao/ComfyUI-Lora-Manager). + +All functions are safe to call when LoraManager is not installed — they +return empty results rather than raising. +""" + +import hashlib +import json +import os +import re +import threading +import urllib.request +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +try: + from ..utils.logging_config import get_logger +except ImportError: + import sys + + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from utils.logging_config import get_logger + +logger = get_logger("prompt_manager.lora_utils") + +# ── LoraManager detection ──────────────────────────────────────────── + + +def find_comfyui_root() -> Optional[Path]: + """Walk upward from this file to find the ComfyUI root (contains main.py). + + Tries both the resolved (real) path and the unresolved path to handle + symlinked custom_nodes installations. + """ + start_paths = [Path(__file__).resolve().parent] + + # If installed via symlink, the unresolved path leads through custom_nodes/ + raw_path = Path(__file__).parent + if raw_path.resolve() != raw_path: + start_paths.append(raw_path) + + # Also try via folder_paths if available (ComfyUI runtime) + try: + import folder_paths + + base = Path(folder_paths.base_path) + if base.is_dir(): + return base + except (ImportError, AttributeError): + # folder_paths unavailable — not running inside ComfyUI runtime + pass + + for start in start_paths: + current = start + for _ in range(10): + if (current / "main.py").exists() and (current / "custom_nodes").exists(): + return current + parent = current.parent + if parent == current: + break + current = parent + + return None + + +def detect_lora_manager(custom_path: str = "") -> Optional[str]: + """Return the absolute path to ComfyUI-Lora-Manager if installed. + + Args: + custom_path: User-provided override path. Checked first. + + Returns: + Absolute path string, or None if not found. + """ + # 1. User override + if custom_path: + p = Path(custom_path) + if p.is_dir() and _looks_like_lora_manager(p): + return str(p.resolve()) + + # 2. Auto-detect via custom_nodes (case-insensitive scan) + root = find_comfyui_root() + if root: + custom_nodes = root / "custom_nodes" + if custom_nodes.is_dir(): + for entry in custom_nodes.iterdir(): + if ( + entry.is_dir() + and "lora" in entry.name.lower() + and "manager" in entry.name.lower() + and _looks_like_lora_manager(entry) + ): + return str(entry.resolve()) + + return None + + +def _looks_like_lora_manager(path: Path) -> bool: + """Heuristic: does this directory look like a LoraManager install?""" + # Must have __init__.py (ComfyUI extension) or README.md + has_init = (path / "__init__.py").exists() + if not has_init: + return False + # Check for characteristic structure: py/ dir, or any .metadata.json nearby + return (path / "py").is_dir() or (path / "lora_manager").is_dir() + + +# ── Metadata reading ───────────────────────────────────────────────── + + +def find_lora_directories(lora_manager_path: str) -> List[str]: + """Find directories that contain LoRA models (with .metadata.json files). + + Searches: ComfyUI models/loras, extra_model_paths.yaml lora dirs, + and the LoraManager extension dir itself. + """ + dirs = set() + lm_path = Path(lora_manager_path) + + root = find_comfyui_root() + if root: + # Default models/loras + models_loras = root / "models" / "loras" + if models_loras.is_dir(): + dirs.add(str(models_loras.resolve())) + + # Extra model paths from ComfyUI config + for extra_dir in _get_extra_lora_paths(root): + if extra_dir.is_dir(): + dirs.add(str(extra_dir.resolve())) + + # Also try folder_paths at runtime (catches all configured paths) + try: + import folder_paths + + for p in folder_paths.get_folder_paths("loras"): + pp = Path(p) + if pp.is_dir(): + dirs.add(str(pp.resolve())) + except (ImportError, AttributeError): + # folder_paths unavailable — not running inside ComfyUI runtime + pass + + # Check for any .metadata.json in the LoraManager dir tree + for meta in lm_path.rglob("*.metadata.json"): + dirs.add(str(meta.parent.resolve())) + + return sorted(dirs) + + +def _get_extra_lora_paths(comfyui_root: Path) -> List[Path]: + """Parse extra_model_paths.yaml for additional LoRA directories.""" + results = [] + for name in ("extra_model_paths.yaml", "extra_model_paths.yml"): + config_file = comfyui_root / name + if not config_file.exists(): + continue + try: + import yaml + + config = yaml.safe_load(config_file.read_text()) + if not isinstance(config, dict): + continue + for section in config.values(): + if not isinstance(section, dict): + continue + base = Path(section.get("base_path", "")) + loras_val = section.get("loras", "") + if not loras_val: + continue + for line in str(loras_val).strip().splitlines(): + line = line.strip() + if not line: + continue + p = Path(line) + if not p.is_absolute(): + p = base / line + if p.is_dir(): + results.append(p) + except Exception as e: + logger.debug(f"Failed to parse {config_file}: {e}") + return results + + +def read_lora_metadata(metadata_path: Path) -> Optional[Dict]: + """Read and parse a single .metadata.json file. + + Returns: + Parsed dict, or None on failure. + """ + try: + with open(metadata_path, "r", encoding="utf-8") as f: + return json.load(f) + except (json.JSONDecodeError, OSError) as e: + logger.debug(f"Failed to read {metadata_path}: {e}") + return None + + +def _get_civitai(metadata: Dict) -> Dict: + """Safely get the civitai dict, handling None values.""" + return metadata.get("civitai") or {} + + +def get_trigger_words_from_metadata(metadata: Dict) -> List[str]: + """Extract trigger words from a parsed LoraManager metadata dict.""" + civitai = _get_civitai(metadata) + words = civitai.get("trainedWords", []) + if isinstance(words, list): + return [w.strip() for w in words if isinstance(w, str) and w.strip()] + return [] + + +def get_example_prompt_from_metadata(metadata: Dict) -> Optional[str]: + """Extract an example prompt from civitai image metadata. + + Looks at civitai.images[].meta.prompt for the first available example. + """ + civitai = _get_civitai(metadata) + images = civitai.get("images", []) or [] + for img in images: + if not isinstance(img, dict): + continue + meta = img.get("meta") + if isinstance(meta, dict): + prompt = meta.get("prompt", "") + if isinstance(prompt, str) and prompt.strip(): + return prompt.strip() + return None + + +def get_civitai_image_urls(metadata: Dict) -> List[str]: + """Extract all civitai example image URLs from metadata.""" + civitai = _get_civitai(metadata) + urls = [] + for img in civitai.get("images", []) or []: + if not isinstance(img, dict): + continue + url = img.get("url", "") + if isinstance(url, str) and url.strip(): + urls.append(url.strip()) + return urls + + +def get_model_name_from_metadata(metadata: Dict) -> str: + """Extract the model display name from metadata.""" + name = metadata.get("model_name", "") + if not name: + civitai = _get_civitai(metadata) + model = civitai.get("model") or {} + name = model.get("name", "") + if not name: + name = metadata.get("file_name", "unknown") + return name + + +def get_preview_images_from_metadata(metadata: Dict, metadata_path: Path) -> List[str]: + """Find all local preview/example image paths for a LoRA. + + Returns: + List of absolute path strings to image files. + """ + results = [] + lora_dir = metadata_path.parent + file_name = metadata.get("file_name", "") + if not file_name: + stem = metadata_path.name.replace(".metadata.json", "") + file_name = stem + + base_name = Path(file_name).stem + + # Check standard preview naming conventions + for ext in ( + ".png", + ".jpg", + ".jpeg", + ".webp", + ".preview.png", + ".preview.jpg", + ".preview.jpeg", + ): + candidate = lora_dir / f"{base_name}{ext}" + if candidate.exists(): + results.append(str(candidate.resolve())) + + return results + + +def get_preview_image_from_metadata( + metadata: Dict, metadata_path: Path +) -> Optional[str]: + """Find the first preview image path for a LoRA (backward compat).""" + images = get_preview_images_from_metadata(metadata, metadata_path) + return images[0] if images else None + + +_THUMB_MAX_SIZE = 512 + + +def _download_one(url: str, local_path: Path, api_key: str) -> Optional[str]: + """Download a single image, resize to thumbnail, save as JPEG.""" + try: + headers = {"User-Agent": "ComfyUI-PromptManager/1.0"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req, timeout=10) as resp: + raw = resp.read() + + # Resize to thumbnail to save disk space + from io import BytesIO + + from PIL import Image + + img = Image.open(BytesIO(raw)) + img.thumbnail((_THUMB_MAX_SIZE, _THUMB_MAX_SIZE), Image.LANCZOS) + img = img.convert("RGB") + img.save(str(local_path), "JPEG", quality=85) + + return str(local_path.resolve()) + except Exception as e: + logger.debug(f"Failed to download {url}: {e}") + return None + + +def download_civitai_images( + metadata: Dict, metadata_path: Path, cache_dir: Path, api_key: str = "" +) -> List[str]: + """Download civitai example images to a local cache directory. + + Uses thumbnail URLs (512px) instead of full-size originals, and + downloads in parallel (up to 8 concurrent) for speed. + + Args: + api_key: CivitAI API key for authenticated downloads (NSFW content). + + Returns: + List of absolute paths to downloaded image files. + """ + from concurrent.futures import ThreadPoolExecutor + + civitai = _get_civitai(metadata) + images = civitai.get("images", []) or [] + if not images: + return [] + + file_name = metadata.get("file_name", "") + if not file_name: + file_name = metadata_path.name.replace(".metadata.json", "") + lora_stem = Path(file_name).stem + + lora_cache = cache_dir / lora_stem + lora_cache.mkdir(parents=True, exist_ok=True) + + # Build download tasks + cached = [] + tasks = [] # (url, local_path) + for img in images: + if not isinstance(img, dict): + continue + url = img.get("url", "") + if not isinstance(url, str) or not url.startswith("http"): + continue + + url_hash = hashlib.md5(url.encode()).hexdigest()[:12] + local_path = lora_cache / f"{url_hash}.jpg" + + if local_path.exists(): + cached.append(str(local_path.resolve())) + else: + tasks.append((url, local_path)) + + if not tasks: + return cached + + # Download in parallel + downloaded = [] + with ThreadPoolExecutor(max_workers=8) as pool: + futures = [ + pool.submit(_download_one, url, path, api_key) for url, path in tasks + ] + for fut in futures: + result = fut.result() + if result: + downloaded.append(result) + + return cached + downloaded + + +def get_lora_image_cache_dir() -> Path: + """Get the directory used to cache downloaded LoRA example images.""" + # Store in the extension's own directory + ext_root = Path(__file__).resolve().parent.parent + cache = ext_root / "data" / "lora_images" + cache.mkdir(parents=True, exist_ok=True) + return cache + + +def get_example_images_dir(lora_manager_path: str) -> Optional[str]: + """Find the LoraManager example_images directory.""" + lm_path = Path(lora_manager_path) + + # Direct subdirectory + candidate = lm_path / "example_images" + if candidate.is_dir(): + return str(candidate.resolve()) + + # Search one level in user data dirs + for child in lm_path.iterdir(): + if child.is_dir(): + sub = child / "example_images" + if sub.is_dir(): + return str(sub.resolve()) + + return None + + +# ── Trigger word cache & injection ─────────────────────────────────── + +_LORA_PATTERN = re.compile(r"]+):[^>]+>", re.IGNORECASE) + + +class TriggerWordCache: + """Thread-safe cache mapping LoRA names to their trigger words. + + Built lazily on first access, refreshable on demand. + """ + + def __init__(self): + self._cache: Dict[str, List[str]] = {} + self._lock = threading.Lock() + self._loaded = False + + def load(self, lora_manager_path: str) -> int: + """Scan LoRA metadata files and build the trigger word mapping. + + Returns: + Number of LoRAs with trigger words found. + """ + new_cache: Dict[str, List[str]] = {} + + lora_dirs = find_lora_directories(lora_manager_path) + for lora_dir in lora_dirs: + dir_path = Path(lora_dir) + for meta_file in dir_path.rglob("*.metadata.json"): + metadata = read_lora_metadata(meta_file) + if not metadata: + continue + + words = get_trigger_words_from_metadata(metadata) + if not words: + continue + + # Key by filename stem (what appears in ) + file_name = metadata.get("file_name", "") + if file_name: + stem = Path(file_name).stem + new_cache[stem.lower()] = words + + # Also key by the metadata file stem + meta_stem = meta_file.name.replace(".metadata.json", "") + if meta_stem.lower() not in new_cache: + new_cache[meta_stem.lower()] = words + + with self._lock: + self._cache = new_cache + self._loaded = True + + logger.info( + f"Trigger word cache loaded: {len(new_cache)} LoRAs with trigger words" + ) + return len(new_cache) + + def get_trigger_words(self, lora_name: str) -> List[str]: + """Look up trigger words for a LoRA by name (case-insensitive).""" + with self._lock: + return self._cache.get(lora_name.lower(), []) + + @property + def is_loaded(self) -> bool: + with self._lock: + return self._loaded + + def clear(self): + with self._lock: + self._cache.clear() + self._loaded = False + + +# Module-level singleton +_trigger_cache = TriggerWordCache() + + +def get_trigger_cache() -> TriggerWordCache: + return _trigger_cache + + +def inject_trigger_words(text: str, cache: TriggerWordCache) -> Tuple[str, List[str]]: + """Scan text for tags and append trigger words. + + Args: + text: The prompt text potentially containing lora tags. + cache: Populated TriggerWordCache instance. + + Returns: + Tuple of (modified_text, list_of_injected_words). + If no trigger words found, returns the original text unchanged. + """ + if not cache.is_loaded: + return text, [] + + matches = _LORA_PATTERN.findall(text) + if not matches: + return text, [] + + all_words = [] + for lora_name in matches: + words = cache.get_trigger_words(lora_name) + for w in words: + if w.lower() not in text.lower() and w not in all_words: + all_words.append(w) + + if not all_words: + return text, [] + + injected = ", ".join(all_words) + return f"{text}, {injected}", all_words diff --git a/pyproject.toml b/pyproject.toml index d782192..04cc3b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "promptmanager" description = "A powerful ComfyUI custom node that extends the standard text encoder with persistent prompt storage, advanced search capabilities, and an automatic image gallery system using SQLite." -version = "3.2.0" +version = "3.2.1" license = {file = "LICENSE"} dependencies = ["# Core dependencies for PromptManager", "# Note: Most dependencies are already included with ComfyUI", "# Already included with Python standard library:", "# - sqlite3", "# - hashlib", "# - json", "# - datetime", "# - os", "# - typing", "# - threading", "# - uuid", "# Required for gallery functionality:", "watchdog>=2.1.0 # For file system monitoring", "Pillow>=8.0.0 # For image metadata extraction (usually included with ComfyUI)", "# Optional dependencies for enhanced search functionality:", "# fuzzywuzzy[speedup]>=0.18.0 # For fuzzy string matching (optional)", "# sqlalchemy>=1.4.0 # For advanced ORM features (optional)", "# Development dependencies (optional):", "# pytest>=6.0.0 # For running tests", "# black>=22.0.0 # For code formatting", "# flake8>=4.0.0 # For linting", "# mypy>=0.910 # For type checking"] diff --git a/tests/test_config.py b/tests/test_config.py index 05037ce..8f90862 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -19,7 +19,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from py.config import GalleryConfig, PromptManagerConfig +from py.config import GalleryConfig, IntegrationConfig, PromptManagerConfig class TestGalleryConfig(unittest.TestCase): @@ -207,5 +207,89 @@ def test_saved_file_is_valid_json(self): os.unlink(tmp.name) +class TestIntegrationConfig(unittest.TestCase): + """Test IntegrationConfig for LoRA Manager settings.""" + + def setUp(self): + """Save original values to restore after each test.""" + self._orig = IntegrationConfig.get_config() + + def tearDown(self): + """Restore original config values.""" + IntegrationConfig.update_config(self._orig) + + def test_reset_to_disabled(self): + """Integrations can be fully disabled via update_config.""" + IntegrationConfig.update_config( + { + "lora_manager": { + "enabled": False, + "path": "", + "trigger_words_enabled": False, + "civitai_api_key": "", + } + } + ) + config = IntegrationConfig.get_config() + lora = config["lora_manager"] + self.assertFalse(lora["enabled"]) + self.assertEqual(lora["path"], "") + self.assertFalse(lora["trigger_words_enabled"]) + self.assertEqual(lora["civitai_api_key"], "") + + def test_get_config_structure(self): + config = IntegrationConfig.get_config() + self.assertIn("lora_manager", config) + lora = config["lora_manager"] + self.assertIn("enabled", lora) + self.assertIn("path", lora) + self.assertIn("trigger_words_enabled", lora) + self.assertIn("civitai_api_key", lora) + + def test_update_config_enables(self): + IntegrationConfig.update_config( + { + "lora_manager": { + "enabled": True, + "path": "/some/path", + "trigger_words_enabled": True, + "civitai_api_key": "test-key-123", + } + } + ) + self.assertTrue(IntegrationConfig.LORA_MANAGER_ENABLED) + self.assertEqual(IntegrationConfig.LORA_MANAGER_PATH, "/some/path") + self.assertTrue(IntegrationConfig.LORA_TRIGGER_WORDS_ENABLED) + self.assertEqual(IntegrationConfig.CIVITAI_API_KEY, "test-key-123") + + def test_update_partial(self): + """Updating one field shouldn't affect others.""" + # Reset to known state first + IntegrationConfig.update_config( + {"lora_manager": {"enabled": False, "path": "/known"}} + ) + # Now update only enabled + IntegrationConfig.update_config({"lora_manager": {"enabled": True}}) + self.assertTrue(IntegrationConfig.LORA_MANAGER_ENABLED) + self.assertEqual(IntegrationConfig.LORA_MANAGER_PATH, "/known") + + def test_update_empty_dict_noop(self): + """Updating with empty dict preserves current state.""" + before = IntegrationConfig.get_config() + IntegrationConfig.update_config({}) + after = IntegrationConfig.get_config() + self.assertEqual(before, after) + + def test_roundtrip(self): + """get_config → update_config → get_config should be stable.""" + IntegrationConfig.update_config( + {"lora_manager": {"enabled": True, "path": "/test"}} + ) + config1 = IntegrationConfig.get_config() + IntegrationConfig.update_config(config1) + config2 = IntegrationConfig.get_config() + self.assertEqual(config1, config2) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_lora_database.py b/tests/test_lora_database.py new file mode 100644 index 0000000..946d6d8 --- /dev/null +++ b/tests/test_lora_database.py @@ -0,0 +1,253 @@ +""" +Database tests for LoRA Manager integration and folder filter features. + +Tests delete_prompts_by_category, search_prompts folder filter, +get_prompt_subfolders, and LoRA-specific prompt workflows using +an in-memory SQLite database. +""" + +import os +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from database.operations import PromptDatabase +from utils.hashing import generate_prompt_hash + + +class LoraDBTestCase(unittest.TestCase): + """Base class with temp database setup/teardown.""" + + def setUp(self): + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + self.temp_db.close() + self.db = PromptDatabase(self.temp_db.name) + + def tearDown(self): + for suffix in ("", "-wal", "-shm"): + path = self.temp_db.name + suffix + if os.path.exists(path): + os.unlink(path) + + def _save(self, text, category=None, tags=None): + """Save a prompt and return its ID.""" + return self.db.save_prompt( + text=text, + category=category, + tags=tags or [], + prompt_hash=generate_prompt_hash(text), + ) + + def _link_image(self, prompt_id, image_path): + """Link a fake image to a prompt.""" + return self.db.link_image_to_prompt( + prompt_id=str(prompt_id), image_path=image_path + ) + + +# ── delete_prompts_by_category ──────────────────────────────────────── + + +class TestDeleteByCategory(LoraDBTestCase): + """Test delete_prompts_by_category for LoRA reimport cleanup.""" + + def test_deletes_matching_category(self): + self._save("lora prompt 1", category="lora-manager") + self._save("lora prompt 2", category="lora-manager") + self._save("keep this", category="general") + + deleted = self.db.delete_prompts_by_category("lora-manager") + + self.assertEqual(deleted, 2) + results = self.db.search_prompts(category="lora-manager") + self.assertEqual(len(results), 0) + + def test_preserves_other_categories(self): + self._save("keep this", category="general") + self._save("and this", category="portraits") + self.db.delete_prompts_by_category("lora-manager") + + results = self.db.search_prompts() + self.assertEqual(len(results), 2) + + def test_returns_zero_when_none_match(self): + self._save("no match", category="general") + deleted = self.db.delete_prompts_by_category("lora-manager") + self.assertEqual(deleted, 0) + + def test_cascades_to_images(self): + pid = self._save("lora with image", category="lora-manager") + self._link_image(pid, "/fake/path/image.jpg") + + # Verify image is linked + images = self.db.get_prompt_images(pid) + self.assertEqual(len(images), 1) + + self.db.delete_prompts_by_category("lora-manager") + + # Prompt gone + results = self.db.search_prompts(category="lora-manager") + self.assertEqual(len(results), 0) + + def test_empty_category_string(self): + self._save("test", category="general") + deleted = self.db.delete_prompts_by_category("") + self.assertEqual(deleted, 0) + + +# ── Folder filter (search_prompts with folder param) ────────────────── + + +class TestFolderFilter(LoraDBTestCase): + """Test search_prompts folder parameter for subfolder filtering.""" + + def _setup_prompts_with_images(self): + """Create prompts linked to images in different directories.""" + pid1 = self._save("landscape prompt", category="nature") + self._link_image(pid1, "/output/landscapes/sunset.png") + + pid2 = self._save("portrait prompt", category="portraits") + self._link_image(pid2, "/output/portraits/face.png") + + pid3 = self._save("another landscape", category="nature") + self._link_image(pid3, "/output/landscapes/mountain.png") + + return pid1, pid2, pid3 + + def test_filter_by_folder(self): + self._setup_prompts_with_images() + results = self.db.search_prompts(folder="landscapes") + self.assertEqual(len(results), 2) + texts = {r["text"] for r in results} + self.assertEqual(texts, {"landscape prompt", "another landscape"}) + + def test_filter_different_folder(self): + self._setup_prompts_with_images() + results = self.db.search_prompts(folder="portraits") + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["text"], "portrait prompt") + + def test_no_match_returns_empty(self): + self._setup_prompts_with_images() + results = self.db.search_prompts(folder="nonexistent") + self.assertEqual(len(results), 0) + + def test_no_folder_returns_all(self): + self._setup_prompts_with_images() + results = self.db.search_prompts() + self.assertGreaterEqual(len(results), 3) + + def test_folder_with_category_filter(self): + self._setup_prompts_with_images() + results = self.db.search_prompts(folder="landscapes", category="nature") + self.assertEqual(len(results), 2) + + +# ── get_prompt_subfolders ───────────────────────────────────────────── + + +class TestGetPromptSubfolders(LoraDBTestCase): + """Test get_prompt_subfolders — extracts unique folder names from images.""" + + def test_extracts_subfolders(self): + pid1 = self._save("prompt 1") + self._link_image(pid1, "/output/folder_a/img1.png") + + pid2 = self._save("prompt 2") + self._link_image(pid2, "/output/folder_b/img2.png") + + folders = self.db.get_prompt_subfolders() + self.assertIsInstance(folders, list) + self.assertGreaterEqual(len(folders), 2) + + def test_deduplicates(self): + pid1 = self._save("prompt 1") + self._link_image(pid1, "/output/same_folder/img1.png") + + pid2 = self._save("prompt 2") + self._link_image(pid2, "/output/same_folder/img2.png") + + folders = self.db.get_prompt_subfolders() + # Count occurrences of the folder — should appear once + matches = [f for f in folders if "same_folder" in f] + self.assertEqual(len(matches), 1) + + def test_empty_database(self): + folders = self.db.get_prompt_subfolders() + self.assertEqual(folders, []) + + def test_returns_sorted(self): + for i, name in enumerate(["charlie", "alpha", "bravo"]): + pid = self._save(f"prompt {i}") + self._link_image(pid, f"/output/{name}/img.png") + + folders = self.db.get_prompt_subfolders() + self.assertEqual(folders, sorted(folders)) + + def test_with_root_dirs(self): + pid = self._save("prompt") + self._link_image(pid, "/output/sub/deep/img.png") + + folders = self.db.get_prompt_subfolders(root_dirs=["/output"]) + self.assertIsInstance(folders, list) + self.assertGreater(len(folders), 0) + + +# ── LoRA prompt workflow ────────────────────────────────────────────── + + +class TestLoraPromptWorkflow(LoraDBTestCase): + """Test the full LoRA import workflow at the database layer.""" + + def test_save_lora_prompt_with_tags(self): + """Simulate what lora_scan does: save prompt with lora-manager tags.""" + pid = self._save( + text="1girl, detailed face, anime style", + category="lora-manager", + tags=["lora-manager", "lora:my_lora", "trigger1"], + ) + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt["category"], "lora-manager") + self.assertIn("lora-manager", prompt["tags"]) + + def test_reimport_clears_and_recreates(self): + """Simulate reimport: delete old, create new.""" + # First import + pid1 = self._save("old lora prompt", category="lora-manager") + self._link_image(pid1, "/cache/old.jpg") + + # Reimport + self.db.delete_prompts_by_category("lora-manager") + + # Second import + pid2 = self._save("new lora prompt", category="lora-manager") + self._link_image(pid2, "/cache/new.jpg") + + results = self.db.search_prompts(category="lora-manager") + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["text"], "new lora prompt") + + def test_hash_dedup_prevents_duplicates(self): + """Verify hash-based dedup works for LoRA prompts.""" + text = "duplicate lora prompt" + h = generate_prompt_hash(text) + + self._save(text, category="lora-manager") + existing = self.db.get_prompt_by_hash(h) + self.assertIsNotNone(existing) + + def test_link_multiple_images_to_lora_prompt(self): + """LoRA prompts can have multiple preview images.""" + pid = self._save("multi-image lora", category="lora-manager") + self._link_image(pid, "/cache/lora/img1.jpg") + self._link_image(pid, "/cache/lora/img2.jpg") + self._link_image(pid, "/cache/lora/img3.jpg") + + images = self.db.get_prompt_images(pid) + self.assertEqual(len(images), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lora_utils.py b/tests/test_lora_utils.py new file mode 100644 index 0000000..3f355c3 --- /dev/null +++ b/tests/test_lora_utils.py @@ -0,0 +1,316 @@ +""" +Unit tests for LoRA Manager integration utilities. + +Tests metadata parsing, trigger word extraction, image URL extraction, +directory detection, TriggerWordCache, and image download logic. +""" + +import json +import os +import sys +import tempfile +import threading +import unittest +from pathlib import Path +from unittest.mock import patch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from py.lora_utils import ( + TriggerWordCache, + get_civitai_image_urls, + get_example_prompt_from_metadata, + get_lora_image_cache_dir, + get_trigger_words_from_metadata, + read_lora_metadata, +) + +# ── Sample metadata fixtures ─────────────────────────────────────────── + + +def _make_metadata( + trained_words=None, + images=None, + model_name="test_lora", + file_name="test.safetensors", +): + """Build a realistic LoRA metadata dict for testing.""" + meta = {"file_name": file_name} + civitai = {} + if trained_words is not None: + civitai["trainedWords"] = trained_words + if images is not None: + civitai["images"] = images + if model_name: + civitai["model"] = {"name": model_name} + if civitai: + meta["civitai"] = civitai + return meta + + +# ── Pure function tests (no mocking) ────────────────────────────────── + + +class TestGetTriggerWords(unittest.TestCase): + """Test get_trigger_words_from_metadata — pure dict extraction.""" + + def test_extracts_words(self): + meta = _make_metadata(trained_words=["word1", "word2", "word3"]) + self.assertEqual( + get_trigger_words_from_metadata(meta), ["word1", "word2", "word3"] + ) + + def test_strips_whitespace(self): + meta = _make_metadata(trained_words=[" padded ", "\ttabbed\t"]) + self.assertEqual(get_trigger_words_from_metadata(meta), ["padded", "tabbed"]) + + def test_filters_empty_strings(self): + meta = _make_metadata(trained_words=["valid", "", " ", "also_valid"]) + self.assertEqual(get_trigger_words_from_metadata(meta), ["valid", "also_valid"]) + + def test_no_civitai_key(self): + self.assertEqual(get_trigger_words_from_metadata({}), []) + + def test_no_trained_words(self): + meta = _make_metadata() + self.assertEqual(get_trigger_words_from_metadata(meta), []) + + def test_trained_words_not_list(self): + meta = {"civitai": {"trainedWords": "not a list"}} + self.assertEqual(get_trigger_words_from_metadata(meta), []) + + def test_non_string_items_filtered(self): + meta = _make_metadata(trained_words=["valid", 123, None, "also_valid"]) + self.assertEqual(get_trigger_words_from_metadata(meta), ["valid", "also_valid"]) + + +class TestGetExamplePrompt(unittest.TestCase): + """Test get_example_prompt_from_metadata — extracts first usable prompt.""" + + def test_extracts_first_prompt(self): + images = [ + {"meta": {"prompt": "a beautiful landscape"}}, + {"meta": {"prompt": "second prompt"}}, + ] + meta = _make_metadata(images=images) + self.assertEqual( + get_example_prompt_from_metadata(meta), "a beautiful landscape" + ) + + def test_skips_empty_prompts(self): + images = [ + {"meta": {"prompt": ""}}, + {"meta": {"prompt": " "}}, + {"meta": {"prompt": "valid prompt"}}, + ] + meta = _make_metadata(images=images) + self.assertEqual(get_example_prompt_from_metadata(meta), "valid prompt") + + def test_no_images(self): + meta = _make_metadata(images=[]) + self.assertIsNone(get_example_prompt_from_metadata(meta)) + + def test_no_civitai(self): + self.assertIsNone(get_example_prompt_from_metadata({})) + + def test_images_without_meta(self): + images = [{"url": "http://example.com/img.jpg"}] + meta = _make_metadata(images=images) + self.assertIsNone(get_example_prompt_from_metadata(meta)) + + def test_meta_without_prompt(self): + images = [{"meta": {"seed": 12345}}] + meta = _make_metadata(images=images) + self.assertIsNone(get_example_prompt_from_metadata(meta)) + + def test_non_dict_images_skipped(self): + images = ["not a dict", None, {"meta": {"prompt": "found it"}}] + meta = _make_metadata(images=images) + self.assertEqual(get_example_prompt_from_metadata(meta), "found it") + + def test_non_string_prompt_skipped(self): + images = [{"meta": {"prompt": 12345}}, {"meta": {"prompt": "real prompt"}}] + meta = _make_metadata(images=images) + self.assertEqual(get_example_prompt_from_metadata(meta), "real prompt") + + +class TestGetCivitaiImageUrls(unittest.TestCase): + """Test get_civitai_image_urls — extracts image URLs from metadata.""" + + def test_extracts_urls(self): + images = [ + {"url": "https://civitai.com/img1.jpg"}, + {"url": "https://civitai.com/img2.jpg"}, + ] + meta = _make_metadata(images=images) + urls = get_civitai_image_urls(meta) + self.assertEqual(len(urls), 2) + self.assertIn("https://civitai.com/img1.jpg", urls) + + def test_filters_empty_urls(self): + images = [{"url": ""}, {"url": "https://civitai.com/valid.jpg"}] + meta = _make_metadata(images=images) + urls = get_civitai_image_urls(meta) + self.assertEqual(urls, ["https://civitai.com/valid.jpg"]) + + def test_no_images(self): + meta = _make_metadata(images=[]) + self.assertEqual(get_civitai_image_urls(meta), []) + + def test_no_civitai(self): + self.assertEqual(get_civitai_image_urls({}), []) + + def test_images_without_url_key(self): + images = [{"id": 1}, {"url": "https://civitai.com/valid.jpg"}] + meta = _make_metadata(images=images) + urls = get_civitai_image_urls(meta) + self.assertEqual(urls, ["https://civitai.com/valid.jpg"]) + + def test_non_dict_images_skipped(self): + images = [None, "bad", {"url": "https://civitai.com/valid.jpg"}] + meta = _make_metadata(images=images) + urls = get_civitai_image_urls(meta) + self.assertEqual(urls, ["https://civitai.com/valid.jpg"]) + + +# ── Filesystem-dependent tests ──────────────────────────────────────── + + +class TestReadLoraMetadata(unittest.TestCase): + """Test read_lora_metadata — file I/O with JSON parsing.""" + + def test_reads_valid_json(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".metadata.json", delete=False + ) as f: + json.dump({"civitai": {"trainedWords": ["test"]}}, f) + f.flush() + path = Path(f.name) + try: + result = read_lora_metadata(path) + self.assertIsNotNone(result) + self.assertEqual(result["civitai"]["trainedWords"], ["test"]) + finally: + os.unlink(path) + + def test_returns_none_for_invalid_json(self): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".metadata.json", delete=False + ) as f: + f.write("not valid json {{{") + f.flush() + path = Path(f.name) + try: + result = read_lora_metadata(path) + self.assertIsNone(result) + finally: + os.unlink(path) + + def test_returns_none_for_missing_file(self): + result = read_lora_metadata(Path("/nonexistent/file.metadata.json")) + self.assertIsNone(result) + + +class TestGetLoraImageCacheDir(unittest.TestCase): + """Test get_lora_image_cache_dir — returns and creates cache path.""" + + def test_returns_path(self): + cache_dir = get_lora_image_cache_dir() + self.assertIsInstance(cache_dir, Path) + self.assertTrue(str(cache_dir).endswith("data/lora_images")) + + def test_directory_exists(self): + cache_dir = get_lora_image_cache_dir() + self.assertTrue(cache_dir.is_dir()) + + +# ── TriggerWordCache tests ──────────────────────────────────────────── + + +class TestTriggerWordCache(unittest.TestCase): + """Test TriggerWordCache — thread-safe trigger word lookup.""" + + def setUp(self): + self.cache = TriggerWordCache() + + def test_initial_state(self): + self.assertFalse(self.cache.is_loaded) + self.assertEqual(self.cache.get_trigger_words("anything"), []) + + def test_load_from_temp_directory(self): + """Create temp metadata files and verify cache loads them.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a metadata file + meta = { + "file_name": "my_lora.safetensors", + "civitai": {"trainedWords": ["trigger1", "trigger2"]}, + } + meta_path = Path(tmpdir) / "my_lora.safetensors.metadata.json" + meta_path.write_text(json.dumps(meta)) + + # Patch find_lora_directories to return our temp dir + with patch("py.lora_utils.find_lora_directories", return_value=[tmpdir]): + count = self.cache.load(tmpdir) + + self.assertTrue(self.cache.is_loaded) + # Cache keys by both file_name stem and metadata filename stem + self.assertGreaterEqual(count, 1) + self.assertEqual( + self.cache.get_trigger_words("my_lora"), ["trigger1", "trigger2"] + ) + + def test_case_insensitive_lookup(self): + with tempfile.TemporaryDirectory() as tmpdir: + meta = { + "file_name": "MyLoRA.safetensors", + "civitai": {"trainedWords": ["word1"]}, + } + (Path(tmpdir) / "MyLoRA.safetensors.metadata.json").write_text( + json.dumps(meta) + ) + + with patch("py.lora_utils.find_lora_directories", return_value=[tmpdir]): + self.cache.load(tmpdir) + + self.assertEqual(self.cache.get_trigger_words("mylora"), ["word1"]) + self.assertEqual(self.cache.get_trigger_words("MYLORA"), ["word1"]) + + def test_clear(self): + # Manually set cache state + self.cache._cache = {"test": ["word"]} + self.cache._loaded = True + + self.cache.clear() + self.assertFalse(self.cache.is_loaded) + self.assertEqual(self.cache.get_trigger_words("test"), []) + + def test_unknown_lora_returns_empty(self): + self.cache._cache = {"known": ["word"]} + self.cache._loaded = True + self.assertEqual(self.cache.get_trigger_words("unknown"), []) + + def test_thread_safety(self): + """Verify concurrent access doesn't raise.""" + self.cache._cache = {"lora": ["word"]} + self.cache._loaded = True + + errors = [] + + def reader(): + try: + for _ in range(100): + self.cache.get_trigger_words("lora") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=reader) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(errors, []) + + +if __name__ == "__main__": + unittest.main() diff --git a/web/admin.html b/web/admin.html index 05282d4..ecccce0 100644 --- a/web/admin.html +++ b/web/admin.html @@ -406,6 +406,56 @@

Choose how the Web UI opens from ComfyUI nodes.

+ + +
+

Integrations

+ +
+
+
+ + checking... +
+ +
+ + +
+
@@ -421,6 +471,32 @@

+ + +