diff --git a/.gitignore b/.gitignore
index bcb744705..eb9383833 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,3 +42,20 @@ loras_i2v/
.env
_test*
+
+*.db
+
+# Test outputs and generated test files
+test_outputs/
+test_results_comparison/
+tests/
+
+# Task database files
+tasks.db-shm
+tasks.db-wal
+
+# venv
+venv/*
+venv312/*
+venv/
+venv312/
\ No newline at end of file
diff --git a/README.md b/README.md
index b060e517b..92941d5e5 100644
--- a/README.md
+++ b/README.md
@@ -1,506 +1,1371 @@
-# WanGP Headless Processing
+# Headless-Wan2GP – Minimal Quick Start
-This document describes the headless processing feature for WanGP, enabling automated video generation by monitoring a task queue. This queue can be a local SQLite database or a centralized PostgreSQL database (e.g., managed by Supabase). Credit for the original Wan2GP repository to [deepbeepmeep](https://github.com/deepbeepmeep).
+This guide shows the two supported ways to run the headless video-generation worker.
-## Overview
+---
-The `headless.py` script is the core worker process that allows users to run WanGP without the Gradio web interface. It continuously polls a task queue for video generation jobs. When a new task is found, it processes it using the `wgp.py` engine or other configured handlers (like ComfyUI).
+## 1 Run Locally (SQLite)
-### Orchestrating multi-step workflows with `steerable_motion.py`
+The worker keeps its task queue in a local `tasks.db` SQLite file and saves all videos under `./outputs`.
-Located *outside* the `Wan2GP/` directory, `steerable_motion.py` is a command-line utility designed to simplify the creation of complex, multi-step video generation workflows by enqueuing a series of coordinated tasks for `headless.py` to process. Instead of manually inserting multiple intricate JSON rows into the database, you can use `steerable_motion.py` to define high-level goals.
-
-It currently provides two main sub-commands:
+```bash
+# 1) Grab the code and enter the folder
+git clone https://github.com/peteromallet/Headless-Wan2GP.git
+cd Headless-Wan2GP
-| Sub-command | Purpose | Typical use-case |
-| ----------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
-| `travel_between_images` | Generates a video that smoothly "travels" between a list of anchor images. It enqueues an orchestrator task, which `headless.py` then uses to manage individual segment generations and final stitching. | Timelapse-like transitions between concept art frames, architectural visualizations. |
-| `different_pose` | Takes a single reference image plus a target prompt and produces a new video of that character in a *different pose*. Internally, it queues tasks for OpenPose extraction and guided video generation. | Turning a static portrait into an animated motion, gesture, or expression change. |
+# 2) Initialize the Wan2GP submodule
+git submodule init
+git submodule update
-All common flags such as `--resolution`, `--seed`, `--debug`, `--use_causvid_lora`, `--execution_engine` (to choose between `wgp` or `comfyui` for generation steps) are accepted by `steerable_motion.py` and forwarded appropriately within the task payloads it creates. The script also ensures the local SQLite database (default: `tasks.db`) and the necessary `tasks` table exist before queuing work.
+# 3) Create a Python ≥3.10 virtual environment (Python 3.12 recommended)
+python3.12 -m venv venv
+source venv/bin/activate
-See `python steerable_motion.py --help` for the full argument list.
+# 4) Install PyTorch (pick the wheel that matches your hardware)
+pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
-**Key Features of the Headless System:**
+# 5) Project requirements
+pip install -r Wan2GP/requirements.txt
+pip install -r requirements.txt
-* **Dual Database Backend:**
- * **SQLite (Default):** Easy to set up, stores tasks in a local `tasks.db` file. Ideal for single-machine use.
- * **PostgreSQL/Supabase:** Allows for a more robust, centralized task queue when configured via a `.env` file. Suitable for multi-worker setups or when a managed database is preferred.
-* **Automated Output Handling:**
- * **SQLite Mode:** Videos are saved locally. For `steerable_motion.py` tasks, intermediate files go into subdirectories within `--output_dir`, and the final video is placed directly in `--output_dir`.
- * **Supabase Mode:** Videos are uploaded to a configured Supabase Storage bucket. The public URL is stored in the database.
- * **Persistent Task Queue:** Tasks are not lost if the server restarts.
- * **Configurable Polling:** Set how often `headless.py` checks for new tasks.
- * **Debug Mode:** Verbose logging for troubleshooting (`--debug` flag for both `headless.py` and `steerable_motion.py`).
- * **Global `wgp.py` Overrides:** Configure `wgp.py` settings at `headless.py` server startup.
+# 6) Create LoRA directories (required for tasks that use LoRAs)
+mkdir -p loras loras_i2v
-## Quick Start / Basic Setup
+# 7) Start the worker – polls the local SQLite DB every 10 s
+python headless.py --main-output-dir ./outputs
+```
-For a quick setup and to run the server (defaults to SQLite mode):
+---
-```bash
-git clone https://github.com/peteromallet/Headless-Wan2GP /workspace/Wan2GP && \\
-cd /workspace/Wan2GP && \\
-apt-get update && apt-get install -y python3.10-venv ffmpeg && \\
-python3.10 -m venv venv && \\
-source venv/bin/activate && \\
-pip install --no-cache-dir torch==2.6.0 torchvision torchaudio -f https://download.pytorch.org/whl/cu124 && \\
-pip install --no-cache-dir -r Wan2GP/requirements.txt && \\
-pip install --no-cache-dir -r requirements.txt
-python Wan2GP/headless.py
-```
+## 2 Run with Supabase (PostgreSQL + Storage)
-Once `headless.py` is running, you can open another terminal to queue tasks using `steerable_motion.py` (see examples below) or by adding tasks directly to the database.
+Use Supabase when you need multiple workers or public download URLs for finished videos.
-
-Detailed Configuration and Supabase Setup
+1. Create a Supabase project and a **public** storage bucket (e.g. `videos`).
+2. In the SQL editor run the helper functions found in `SUPABASE_SETUP.md` to create the `tasks` table and RPC utilities.
+3. Add a `.env` file at the repo root:
-1. **Clone the Repository (if not done by Quick Start):**
- ```bash
- git clone https://github.com/peteromallet/Headless-Wan2GP # Or your fork
- cd Headless-Wan2GP
- # Note: The main scripts like steerable_motion.py are at the root.
- # Wan2GP-specific code is under the Wan2GP/ subdirectory.
+ ```env
+ DB_TYPE=supabase
+	SUPABASE_URL=https://<your-project-ref>.supabase.co
+	SUPABASE_SERVICE_KEY=<your-service-role-key>
+ SUPABASE_VIDEO_BUCKET=videos
+ POSTGRES_TABLE_NAME=tasks # optional (defaults to "tasks")
```
+4. Install dependencies as in the local setup, then run:
-2. **Create a Virtual Environment (if not done by Quick Start):**
```bash
- # Ensure python3.10-venv or equivalent is installed
- # apt-get update && apt-get install -y python3.10-venv
- python3.10 -m venv venv
- source venv/bin/activate
+ python headless.py --main-output-dir ./outputs
```
-3. **Install PyTorch (if not done by Quick Start):**
- Ensure you install a version of PyTorch compatible with your CUDA version (if using GPU).
- ```bash
- # Example for CUDA 12.4 (adjust as needed)
- pip install --no-cache-dir torch==2.6.0 torchvision torchaudio -f https://download.pytorch.org/whl/cu124
- ```
+### Auth: service-role key vs. personal token
-4. **Install Python Dependencies (if not done by Quick Start):**
- There are two main `requirements.txt` files:
- * `Wan2GP/requirements.txt`: For the core Wan2GP library.
- * `requirements.txt` (at the root): For `steerable_motion.py` and `headless.py` (includes `supabase`, `python-dotenv`, etc.).
- ```bash
- pip install --no-cache-dir -r Wan2GP/requirements.txt
- pip install --no-cache-dir -r requirements.txt
- ```
+The worker needs a Supabase token to read/write the `tasks` table and upload the final video file.
-5. **Environment Configuration (`.env` file):**
- Create a `.env` file in the root directory of the `Headless-Wan2GP` project (i.e., next to `steerable_motion.py` and `headless.py`).
- ```env
- # --- Database Configuration ---
- # DB_TYPE: "sqlite" (default) or "supabase" (for PostgreSQL via Supabase)
- DB_TYPE=sqlite
-
- # If DB_TYPE=sqlite, you can optionally specify a custom path for the SQLite DB file.
- # If not set, defaults to "tasks.db" in the current working directory.
- # SQLITE_DB_PATH_ENV="path/to/your/custom_tasks.db"
-
- # If DB_TYPE=supabase:
- # POSTGRES_TABLE_NAME: Desired table name for tasks, used in RPC calls. Default: "tasks"
- POSTGRES_TABLE_NAME="tasks"
- SUPABASE_URL="https://your-project-ref.supabase.co"
- SUPABASE_SERVICE_KEY="your_supabase_service_role_key" # Keep this secret!
- SUPABASE_VIDEO_BUCKET="videos" # Your Supabase storage bucket name
-
- # --- ComfyUI Configuration (Optional) ---
- # If using the "comfyui" execution_engine, headless.py needs to know where ComfyUI saves outputs.
- # COMFYUI_OUTPUT_PATH="/path/to/your/ComfyUI/output"
- ```
- * `headless.py` and `steerable_motion.py` will load these variables.
- * If `DB_TYPE=supabase` is set but `SUPABASE_URL` or `SUPABASE_SERVICE_KEY` are missing, scripts will warn and may fall back to SQLite or fail.
-
-6. **Supabase Setup (If using `DB_TYPE=supabase`):**
- * **SQL Functions (CRITICAL):** You MUST create specific SQL functions in your Supabase PostgreSQL database. `headless.py` relies on these. Go to your Supabase Dashboard -> SQL Editor -> "New query" and execute the following definitions:
-
- **Function 1: `func_initialize_tasks_table`**
- ```sql
- CREATE OR REPLACE FUNCTION func_initialize_tasks_table(p_table_name TEXT)
- RETURNS VOID AS $$
- BEGIN
- EXECUTE format('
- CREATE TABLE IF NOT EXISTS %I (
- id BIGSERIAL PRIMARY KEY,
- task_id TEXT UNIQUE NOT NULL,
- params JSONB NOT NULL,
- task_type TEXT NOT NULL, -- Added NOT NULL constraint
- status TEXT NOT NULL DEFAULT ''Queued'',
- worker_id TEXT NULL,
- output_location TEXT NULL,
- created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
- );', p_table_name);
-
- EXECUTE format('
- CREATE INDEX IF NOT EXISTS idx_%s_status_created_at ON %I (status, created_at);
- ', p_table_name, p_table_name);
-
- EXECUTE format('
- CREATE UNIQUE INDEX IF NOT EXISTS idx_%s_task_id ON %I (task_id);
- ', p_table_name, p_table_name);
-
- -- Index for task_type, useful for querying specific task types
- EXECUTE format('
- CREATE INDEX IF NOT EXISTS idx_%s_task_type ON %I (task_type);
- ', p_table_name, p_table_name);
-
- -- Index for orchestrator_run_id within params, useful for travel tasks
- -- Note: This creates an index on a JSONB path. Ensure your Postgres version supports this efficiently.
- EXECUTE format('
- CREATE INDEX IF NOT EXISTS idx_%s_params_orchestrator_run_id ON %I ((params->>''orchestrator_run_id''));
- ', p_table_name, p_table_name);
-
- END;
- $$ LANGUAGE plpgsql SECURITY DEFINER;
- ```
-
- **Function 2: `func_claim_task`** (Updated to return `task_type_out`)
- ```sql
- CREATE OR REPLACE FUNCTION func_claim_task(p_table_name TEXT, p_worker_id TEXT)
- RETURNS TABLE(task_id_out TEXT, params_out JSONB, task_type_out TEXT) AS $$
- DECLARE
- v_task_id TEXT;
- v_params JSONB;
- v_task_type TEXT;
- BEGIN
- EXECUTE format('
- WITH selected_task AS (
- SELECT id, task_id, params, task_type -- Include task_type
- FROM %I
- WHERE status = ''Queued''
- ORDER BY created_at ASC
- LIMIT 1
- FOR UPDATE SKIP LOCKED
- ), updated_task AS (
- UPDATE %I
- SET
- status = ''In Progress'',
- worker_id = $1,
- updated_at = CURRENT_TIMESTAMP
- WHERE id = (SELECT st.id FROM selected_task st)
- RETURNING task_id, params, task_type -- Return task_type
- )
- SELECT ut.task_id, ut.params, ut.task_type FROM updated_task ut LIMIT 1',
- p_table_name, p_table_name
- )
- INTO v_task_id, v_params, v_task_type -- Store task_type
- USING p_worker_id;
-
- IF v_task_id IS NOT NULL THEN
- RETURN QUERY SELECT v_task_id, v_params, v_task_type;
- ELSE
- RETURN QUERY SELECT NULL::TEXT, NULL::JSONB, NULL::TEXT WHERE FALSE;
- END IF;
- END;
- $$ LANGUAGE plpgsql SECURITY DEFINER;
- ```
-
- **Function 3: `func_update_task_status`** (No changes needed from previous version if it handled `p_output_location` correctly)
- ```sql
- CREATE OR REPLACE FUNCTION func_update_task_status(
- p_table_name TEXT,
- p_task_id TEXT,
- p_status TEXT,
- p_output_location TEXT DEFAULT NULL
- )
- RETURNS VOID AS $$
- BEGIN
- IF p_status = 'Complete' AND p_output_location IS NOT NULL THEN
- EXECUTE format('
- UPDATE %I
- SET status = $1, updated_at = CURRENT_TIMESTAMP, output_location = $2
- WHERE task_id = $3;',
- p_table_name)
- USING p_status, p_output_location, p_task_id;
- ELSE
- EXECUTE format('
- UPDATE %I
- SET status = $1, updated_at = CURRENT_TIMESTAMP
- WHERE task_id = $2;',
- p_table_name)
- USING p_status, p_task_id;
- END IF;
- END;
- $$ LANGUAGE plpgsql SECURITY DEFINER;
- ```
-
- **Function 4: `func_migrate_tasks_for_task_type`** (New, for schema migration)
- ```sql
- CREATE OR REPLACE FUNCTION func_migrate_tasks_for_task_type(p_table_name TEXT)
- RETURNS TEXT AS $$
- DECLARE
- col_exists BOOLEAN;
- migrated_count INTEGER := 0;
- defaulted_count INTEGER := 0;
- BEGIN
- -- Check if task_type column exists
- SELECT EXISTS (
- SELECT 1
- FROM information_schema.columns
- WHERE table_schema = current_schema() -- or your specific schema
- AND table_name = p_table_name
- AND column_name = 'task_type'
- ) INTO col_exists;
-
- IF NOT col_exists THEN
- RAISE NOTICE 'Column task_type does not exist in %. Adding column.', p_table_name;
- EXECUTE format('ALTER TABLE %I ADD COLUMN task_type TEXT', p_table_name);
- ELSE
- RAISE NOTICE 'Column task_type already exists in %.', p_table_name;
- END IF;
-
- -- Populate task_type from params where task_type is NULL
- RAISE NOTICE 'Attempting to populate task_type from params->>''task_type'' for NULL rows in %s...', p_table_name;
- EXECUTE format('
- WITH updated_rows AS (
- UPDATE %I
- SET task_type = params->>''task_type''
- WHERE task_type IS NULL AND params->>''task_type'' IS NOT NULL
- RETURNING 1
- )
- SELECT count(*) FROM updated_rows;',
- p_table_name
- ) INTO migrated_count;
- RAISE NOTICE 'Populated task_type for % rows from params.', migrated_count;
-
- -- Optionally, default remaining NULL task_types for old standard tasks
- RAISE NOTICE 'Attempting to default remaining NULL task_type to ''standard_wgp_task'' in %s...', p_table_name;
- EXECUTE format('
- WITH defaulted_rows AS (
- UPDATE %I
- SET task_type = ''standard_wgp_task''
- WHERE task_type IS NULL
- RETURNING 1
- )
- SELECT count(*) FROM defaulted_rows;',
- p_table_name
- ) INTO defaulted_count;
- RAISE NOTICE 'Defaulted task_type for % rows.', defaulted_count;
-
- -- Ensure NOT NULL constraint if this function is also for initial setup
- -- However, func_initialize_tasks_table should handle this for new tables.
- -- If altering, ensure data is clean first or handle potential errors.
- -- EXECUTE format('ALTER TABLE %I ALTER COLUMN task_type SET NOT NULL;', p_table_name);
- -- RAISE NOTICE 'Applied NOT NULL constraint to task_type in %.', p_table_name;
-
- RETURN 'Migration for task_type completed. Migrated from params: ' || migrated_count || ', Defaulted NULLs: ' || defaulted_count;
- EXCEPTION
- WHEN OTHERS THEN
- RAISE WARNING 'Error during func_migrate_tasks_for_task_type for table %: %', p_table_name, SQLERRM;
- RETURN 'Migration for task_type failed. Check database logs.';
- END;
- $$ LANGUAGE plpgsql SECURITY DEFINER;
- ```
-
- **Function 5: `func_get_completed_travel_segments`** (New, for stitcher task)
- ```sql
- CREATE OR REPLACE FUNCTION func_get_completed_travel_segments(p_run_id TEXT)
- RETURNS TABLE(segment_index_out INTEGER, output_location_out TEXT) AS $$
- BEGIN
- RETURN QUERY
- SELECT
- CAST(params->>'segment_index' AS INTEGER) as segment_idx,
- output_location
- FROM
- tasks -- Assuming your table is named 'tasks', or use p_table_name if dynamic
- WHERE
- params->>'orchestrator_run_id' = p_run_id
- AND task_type = 'travel_segment'
- AND status = 'Complete'
- AND output_location IS NOT NULL
- ORDER BY
- segment_idx ASC;
- END;
- $$ LANGUAGE plpgsql;
- -- Consider SECURITY DEFINER if needed, depending on row-level security policies.
- ```
-
- * **Database Table:** The `func_initialize_tasks_table` will create the tasks table.
- * **Storage Bucket:** Create a storage bucket in Supabase (e.g., "videos"). Ensure it's public for URLs to work.
-
-7. **FFmpeg:** Ensure `ffmpeg` is installed and accessible in your system's PATH, as it's used for various video processing utilities.
-
-
-
-How it Works (Updated)
-
-1. **Task Definition (in the Database):**
- Tasks are stored as rows in the database. Key columns:
- * `task_id` (TEXT, UNIQUE)
- * `params` (JSON string/object): Contains all parameters for the task. For `steerable_motion.py` tasks, this will include a nested `orchestrator_details` payload for `travel_orchestrator` tasks, which is then passed to subsequent `travel_segment` and `travel_stitch` tasks.
- * `task_type` (TEXT, NOT NULL): Specifies the handler in `headless.py` (e.g., "standard_wgp_task", "generate_openpose", "travel_orchestrator", "travel_segment", "travel_stitch", "comfyui_workflow").
- * `status` (TEXT): "Queued", "In Progress", "Complete", "Failed".
- * `output_location` (TEXT): Local path (SQLite) or public URL (Supabase) of the final primary artifact.
-
- **Example of a `travel_orchestrator` task's `params` (simplified):**
- ```json
- {
- // task_id, task_type, etc. are top-level columns in the DB
- "orchestrator_details": {
- "orchestrator_task_id": "sm_travel_orchestrator_XYZ",
- "run_id": "run_ABC",
- "original_task_args": {
- "input_images": ["img1.png", "img2.png"],
- "base_prompts": ["prompt1"],
- /* ...other steerable_motion.py args... */
- },
- "original_common_args": {/* ...common args... */},
- "input_image_paths_resolved": ["/abs/path/to/img1.png", "/abs/path/to/img2.png"],
- "num_new_segments_to_generate": 1,
- "base_prompts_expanded": ["prompt1"],
- "segment_frames_expanded": [81],
- "frame_overlap_expanded": [16],
- // ... and many other parameters for headless tasks to consume ...
- "main_output_dir_for_run": "/path/to/steerable_motion_output"
- }
- }
- ```
+| Token type | Who should use it | Permissions in this repo |
+|------------|------------------|--------------------------|
+| **Service-role key** | Self-hosted backend worker(s) you control | Full access to every project row and storage object (admin). Recommended for private deployments. *Never expose this key in a client app.* |
+| **Personal access token (PAT)** or user JWT | Scripts/apps running on behalf of a **single Supabase user** | Access is limited to rows whose `project_id` you own. Provide `project_id` in each API call so the Edge Functions can enforce ownership. |
+
+If you set `SUPABASE_SERVICE_KEY` in the `.env`, the worker will authenticate as an admin. If you instead pass a PAT via the `Authorization: Bearer <token>` header when calling the Edge Functions, the backend will validate ownership before returning data.
+
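+A minimal sketch of a PAT-authenticated Edge Function call (the function name `create-task` and the payload shape are placeholders — check your deployed Edge Functions for the actual names; only the `Authorization` header and the `project_id` requirement come from the table above):
+
+```bash
+# Hypothetical example: calling an Edge Function with a personal access token.
+# Replace <project-ref>, <your-pat> and <your-project-id> with real values.
+curl -X POST "https://<project-ref>.supabase.co/functions/v1/create-task" \
+  -H "Authorization: Bearer <your-pat>" \
+  -H "Content-Type: application/json" \
+  -d '{"project_id": "<your-project-id>", "task_type": "single_image", "params": {"prompt": "A beautiful landscape", "model": "t2v"}}'
+```
+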
+The worker will automatically upload finished videos to the bucket and store the public URL in the database.
+
+---
+
+## Adding Jobs to the Queue
+
+Use the `add_task.py` script to add tasks to the queue:
+
+```bash
+python add_task.py --type <task_type> --params '<json_string or @file.json>' [--dependant-on <task_id>]
+```
+
+### Parameters:
+
+- `--type`: Task type string (required)
+- `--params`: JSON string with task payload OR @path-to-json-file (required)
+- `--dependant-on`: Optional task_id that this new task depends on
+
+### Character Escaping in JSON
+
+When passing JSON strings on the command line, special characters need proper escaping:
+
+#### Apostrophes and Quotes
+```bash
+# ✅ Correct: Close the single-quoted string and insert an escaped apostrophe ('\'')
+python add_task.py --type video_generation --params '{"prompt": "The soccer ball'\''s position changes", "model": "t2v"}'
+
+# ✅ Alternative: Use double quotes for the outer string
+python add_task.py --type video_generation --params "{\"prompt\": \"The soccer ball's position changes\", \"model\": \"t2v\"}"
+
+# ✅ Best practice: Use JSON files for complex prompts
+echo '{"prompt": "The soccer ball'\''s position changes dynamically", "model": "t2v"}' > task.json
+python add_task.py --type video_generation --params @task.json
+```
+
+#### Common Escaping Examples
+```bash
+# Quotes within prompts
+python add_task.py --type video_generation --params '{"prompt": "A character saying \"Hello world\"", "model": "t2v"}'
+
+# Backslashes (escape each backslash as \\ in JSON)
+python add_task.py --type video_generation --params '{"prompt": "A path like C:\\Users\\folder", "model": "t2v"}'
+
+# Newlines and special characters (\n becomes a line break after JSON parsing)
+python add_task.py --type video_generation --params '{"prompt": "Line 1\nLine 2 with special chars: @#$%", "model": "t2v"}'
+```
+
+#### JSON File Method (Recommended for Complex Prompts)
+For prompts with many special characters, create a JSON file:
+
+**task_params.json:**
+```json
+{
+ "prompt": "The dragon's breath illuminates the knight's armor, creating a \"magical\" atmosphere with various symbols: @#$%^&*()",
+ "negative_prompt": "Don't show blurry or low-quality details",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "seed": 42
+}
+```
+
+**Command:**
+```bash
+python add_task.py --type video_generation --params @task_params.json
+```
+
+### Examples:
+
+1. **Simple task with JSON string:**
+
+```bash
+python add_task.py --type single_image --params '{"prompt": "A beautiful landscape", "model": "t2v", "resolution": "512x512", "seed": 12345}'
+```
+
+2. **Task with JSON file:**
+
+```bash
+python add_task.py --type travel_orchestrator --params @my_task_params.json
+```
+
+3. **Task with dependency:**
+
+```bash
+python add_task.py --type generate_openpose --params '{"image_path": "input.png"}' --dependant-on task_12345
+```
+
+4. **Task with apostrophes in prompt:**
+
+```bash
+python add_task.py --type video_generation --params '{"prompt": "The hero'\''s journey begins at dawn", "model": "t2v", "resolution": "1280x720"}'
+```
+
+## Task Types and Examples
+
+### 1. Single Image Generation (`single_image`)
+
+Generates a single image from a text prompt.
+
+**Example payload:**
+```json
+{
+ "prompt": "A futuristic city skyline at night",
+ "model": "t2v",
+ "resolution": "512x512",
+ "seed": 12345,
+ "negative_prompt": "blurry, low quality",
+ "use_causvid_lora": true
+}
+```
+
+### 2. Travel Orchestrator (`travel_orchestrator`)
+
+Manages complex travel sequences between multiple images.
+
+**Example payload:**
+```json
+{
+ "project_id": "my_travel_project",
+ "orchestrator_details": {
+ "run_id": "travel_20250814",
+ "input_image_paths_resolved": ["image1.png", "image2.png", "image3.png"],
+ "parsed_resolution_wh": "512x512",
+ "model_name": "vace_14B",
+ "use_causvid_lora": true,
+ "num_new_segments_to_generate": 2,
+ "base_prompts_expanded": ["Futuristic landscape", "Alien world"],
+ "negative_prompts_expanded": ["blurry", "low quality"],
+ "segment_frames_expanded": [81, 81],
+ "frame_overlap_expanded": [12, 12],
+ "fps_helpers": 16,
+ "fade_in_params_json_str": "{\"low_point\": 0.0, \"high_point\": 1.0, \"curve_type\": \"ease_in_out\", \"duration_factor\": 0.0}",
+ "fade_out_params_json_str": "{\"low_point\": 0.0, \"high_point\": 1.0, \"curve_type\": \"ease_in_out\", \"duration_factor\": 0.0}",
+ "seed_base": 11111,
+ "main_output_dir_for_run": "./outputs",
+ "debug_mode_enabled": true,
+ "skip_cleanup_enabled": true
+ }
+}
+```
+
+### 3. Generate OpenPose (`generate_openpose`)
+
+Generates OpenPose skeleton images from input images.
+
+**Example payload:**
+```json
+{
+ "image_path": "input.png",
+ "output_dir": "openpose_output"
+}
+```
+
+### 4. Different Perspective (`different_perspective_orchestrator`)
+
+Generates videos from different perspectives of a single image.
+
+**Example payload:**
+```json
+{
+ "project_id": "perspective_project",
+ "run_id": "perspective_20250814",
+ "input_image_path": "input.png",
+ "prompt": "Cinematic view from a different angle",
+ "model_name": "vace_14B",
+ "resolution": "700x400",
+ "fps_helpers": 16,
+ "output_video_frames": 30,
+ "seed": 11111,
+ "use_causvid_lora": true,
+ "debug_mode": true,
+ "skip_cleanup": true,
+ "perspective_type": "pose"
+}
+```
+
+## Video Generation Examples with Different Models
+
+### Text-to-Video Examples
-2. **Polling Mechanism:** `headless.py` polls for 'Queued' tasks.
+#### Wan2.1 Text-to-Video 14B (t2v):
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A spaceship traveling through hyperspace",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "seed": 42,
+ "video_length": 81,
+ "num_inference_steps": 30
+}'
+```
+
+#### Wan2.1 Text-to-Video 1.3B (t2v_1.3B):
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A futuristic city with flying cars",
+ "model": "t2v_1.3B",
+ "resolution": "832x480",
+ "seed": 123,
+ "video_length": 81,
+ "num_inference_steps": 30
+}'
+```
-3. **Task Processing:**
- * `headless.py` claims a task, updates status to 'In Progress'.
- * It calls the appropriate handler based on `task_type`.
- * For `travel_orchestrator`: Enqueues the first `travel_segment` task.
- * For `travel_segment`: Creates guide videos, runs WGP/ComfyUI generation (as a sub-task), and then enqueues the next segment or the `travel_stitch` task.
- * For `travel_stitch`: Collects all segment videos, stitches them (with crossfades), optionally upscales (as a sub-task), and saves the final video.
- * Outputs are handled (local save or Supabase upload).
+#### With Single LoRA:
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A spaceship traveling through hyperspace",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "seed": 42,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "use_causvid_lora": true,
+ "lora_name": "your_lora_name"
+}'
+```
-4. **Task Completion:** Status becomes 'Complete'/'Failed', `output_location` is updated.
-
+#### With Multiple LoRAs:
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A cyberpunk samurai with neon lights",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "seed": 42,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "activated_loras": ["cyberpunk_style", "samurai_character", "neon_effects"],
+ "loras_multipliers": "1.2 0.8 1.0"
+}'
+```
-## Usage
+### Image-to-Video Examples
-### 1. Start the Headless Worker:
+#### Image-to-Video 480p (i2v):
```bash
-# Ensure your .env file is configured if using Supabase
-python Wan2GP/headless.py --main-output-dir ./my_video_outputs --poll-interval 5
-# Add --debug for verbose logs
+python add_task.py --type video_generation --params '{
+ "prompt": "Transform this image into a dynamic video",
+ "model": "i2v",
+ "resolution": "832x480",
+ "seed": 101,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_image_here"
+}'
```
-`headless.py` will continuously monitor the task queue.
-### 2. Queue Tasks using `steerable_motion.py`:
+#### Image-to-Video 720p (i2v_720p):
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Create a cinematic video from this image",
+ "model": "i2v_720p",
+ "resolution": "1280x720",
+ "seed": 202,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_image_here"
+}'
+```
-Once `headless.py` is running, execute the following in another terminal:
+### ControlNet Examples
-**A. Example `travel_between_images` with Sample Images:**
+#### Vace 14B (vace_14B):
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A character in a dynamic pose",
+ "model": "vace_14B",
+ "resolution": "832x480",
+ "seed": 456,
+ "video_length": 81,
+ "num_inference_steps": 30
+}'
+```
+#### Vace 1.3B (vace_1.3B):
```bash
-python steerable_motion.py travel_between_images \\
- --input_images samples/image_1.png samples/image_2.png samples/image_3.png \\
- --base_prompts "Transitioning from red" "Moving to green" \\
- --resolution "320x240" \\
- --segment_frames 30 \\
- --frame_overlap 10 \\
- --model_name "vace_14B" \\
- --seed 789 \\
- --output_dir ./my_video_outputs \\
- --debug
+python add_task.py --type video_generation --params '{
+ "prompt": "A person dancing in the rain",
+ "model": "vace_1.3B",
+ "resolution": "832x480",
+ "seed": 789,
+ "video_length": 81,
+ "num_inference_steps": 30
+}'
```
-**B. Example `different_pose` with a Sample Image:**
+### Advanced Image-to-Video Examples
+
+#### Fun InP 14B (fun_inp) - Supports End Frames:
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Smooth transition between start and end images",
+ "model": "fun_inp",
+ "resolution": "832x480",
+ "seed": 303,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_start_image",
+ "image_end": "base64_encoded_end_image"
+}'
+```
+#### Fun InP 1.3B (fun_inp_1.3B):
```bash
-python steerable_motion.py different_pose \\
- --input_image samples/image_1.png \\
- --prompt "A red square, now animated and waving" \\
- --resolution "320x240" \\
- --output_video_frames 30 \\
- --model_name "vace_14B" \\
- --seed 101 \\
- --output_dir ./my_video_outputs \\
- --debug
+python add_task.py --type video_generation --params '{
+ "prompt": "Creative morphing between two images",
+ "model": "fun_inp_1.3B",
+ "resolution": "832x480",
+ "seed": 404,
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_start_image"
+}'
```
-Remember to have `headless.py` running in a separate terminal to process these queued tasks.
+---
+
+## Complete Model Reference
+
+Based on the Wan2GP codebase analysis, here are all available models and their recommended parameters:
+
+### Text-to-Video Models
+
+#### Standard Models
+- **`t2v`** - Wan2.1 Text-to-Video 14B (recommended for high quality)
+- **`t2v_1.3B`** - Wan2.1 Text-to-Video 1.3B (faster, lower VRAM)
+
+#### Specialized Text-to-Video Models
+- **`sky_df_1.3B`** - SkyReels2 Diffusion Forcing 1.3B (for long videos)
+- **`sky_df_14B`** - SkyReels2 Diffusion Forcing 14B (for long videos, higher quality)
+- **`sky_df_720p_14B`** - SkyReels2 Diffusion Forcing 720p 14B (for long 720p videos)
+- **`ltxv_13B`** - LTX Video 13B (fast, up to 260 frames)
+- **`ltxv_13B_distilled`** - LTX Video Distilled 13B (very fast)
+- **`hunyuan`** - Hunyuan Video 720p 13B (best quality text-to-video)
+- **`moviigen`** - MoviiGen 1080p 14B (cinematic quality, 21:9 ratio)
+
+### Image-to-Video Models
+
+#### Standard Image-to-Video
+- **`i2v`** - Wan2.1 Image-to-Video 480p 14B
+- **`i2v_720p`** - Wan2.1 Image-to-Video 720p 14B (higher resolution)
+- **`hunyuan_i2v`** - Hunyuan Video Image-to-Video 720p 13B
+
+#### Advanced Image-to-Video
+- **`fun_inp`** - Fun InP Image-to-Video 14B (supports end frames)
+- **`fun_inp_1.3B`** - Fun InP Image-to-Video 1.3B (faster)
+- **`flf2v_720p`** - First Last Frame 2 Video 720p 14B (official start/end frame model)
+- **`fantasy`** - Fantasy Speaking 720p 14B (with audio input support)
-### Command-Line Arguments for `headless.py`:
-(Content from previous README is largely still valid here, ensure it's present)
-* **Server Settings:**
- * `--db-file` (Used for SQLite mode only): Path to the SQLite database file. Defaults to `tasks.db` (or value from `SQLITE_DB_PATH_ENV` in `.env`).
- * `--main-output-dir`: Base directory for outputs (e.g., where `steerable_motion.py` tells tasks to save, and where `headless.py` might save for non-`steerable_motion` tasks if they don't specify a full sub-path). Defaults to `./outputs`.
- * `--poll-interval`: How often (in seconds) to check the database for new tasks. Defaults to 10 seconds.
- * `--debug`: Enable verbose debug logging.
+### ControlNet Models
-* **WGP Global Config Overrides (Optional - Applied once at server start):**
- (List of `--wgp-*` arguments from previous README is still valid)
+#### Vace ControlNet (for pose, depth, custom control)
+- **`vace_1.3B`** - Vace ControlNet 1.3B
+- **`vace_14B`** - Vace ControlNet 14B (recommended)
-## Advanced: Video-Guided Generation (VACE)
+#### Specialized Control Models
+- **`phantom_1.3B`** - Phantom 1.3B (object/person transfer)
+- **`phantom_14B`** - Phantom 14B (object/person transfer, higher quality)
+- **`hunyuan_custom`** - Hunyuan Custom 720p 13B (person identity preservation)
+- **`hunyuan_avatar`** - Hunyuan Avatar 720p 13B (audio-driven animation)
-The `params` field for each task in the database can include `video_prompt_type`, `video_guide_path`, `video_mask_path`, and `keep_frames_video_guide` as described below to control VACE features. The pose, depth or greyscale information is always extracted **from the `video_guide_path` file**; the optional `video_mask_path` is only for in/out-painting control.
+### Utility Models
+- **`recam_1.3B`** - ReCamMaster 1.3B (camera movement replay)
-**Key JSON fields in task `params` for VACE:**
+### Resolution Guidelines by Model
+
+| Model Type | Recommended Resolutions | Max Resolution |
+|------------|------------------------|----------------|
+| `t2v_1.3B`, `fun_inp_1.3B` | `832x480`, `480x832` | 848x480 equivalent |
+| `t2v`, `i2v`, `fun_inp` | `832x480`, `1280x720` | No strict limit |
+| `i2v_720p`, `flf2v_720p` | `1280x720`, `720x1280` | 720p optimized |
+| `sky_df_720p_14B` | `1280x720`, `960x544` | 720p optimized |
+| `hunyuan*` | `1280x720` | 720p optimized |
+| `ltxv*` | `832x480`, `1280x720` | Flexible |
+| `moviigen` | `1920x832` (21:9) | 1080p cinematic |
+
+### Frame Length Guidelines
+
+| Model Type | Default Frames | Max Frames | FPS |
+|------------|----------------|------------|-----|
+| Standard Wan models | 81 | 193 | 16 |
+| `sky_df*` (Diffusion Forcing) | 97 | 737 | 24 |
+| `ltxv*` | 97 | 737 | 30 |
+| `hunyuan*` | 97 | 337-401 | 24-25 |
+| `fantasy` | 81 | 233 | 23 |
+| `recam_1.3B` | 81 | 193 (locked) | 16 |
+
+### Example with All Common Parameters
+
+```bash
+python add_task.py --type single_image --params '{
+ "prompt": "A majestic dragon flying over mountains at sunset",
+ "negative_prompt": "blurry, low quality, distorted",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "guidance_scale": 5.0,
+ "flow_shift": 5.0,
+ "seed": 42,
+ "repeat_generation": 1,
+ "use_causvid_lora": false
+}'
+```
+
+---
+
+## Complete Parameter Schema
+
+### Task Types
+
+| Task Type | Description | Use Case |
+|-----------|-------------|----------|
+| `video_generation` | Generate videos from text/image prompts | Most common video generation tasks |
+| `single_image` | Generate single images | Image generation only (no video_length) |
+| `travel_orchestrator` | Complex multi-segment video sequences | Advanced orchestrated workflows |
+| `generate_openpose` | Extract pose information from images | Preprocessing for pose-based generation |
+| `different_perspective_orchestrator` | Generate videos from different angles | Perspective-based video generation |
+
+### All Supported Parameters
+
+Based on the [`generate_video`](Wan2GP/wgp.py:2694) function analysis:
+
+#### Core Parameters
```json
{
- // ... other parameters ...
- "video_prompt_type": "PV", // String of capital letters described below
- "video_guide_path": "path/to/your/control_video.mp4", // Path to the control video, accessible by the server
- "video_mask_path": "path/to/your/mask_video.mp4", // (Optional) Path to a mask video, accessible by the server
- "keep_frames_video_guide": "1:10 -1", // (Optional) List of whole frames to preserve from the guide
- "image_refs_paths": ["path/to/ref_image1.png"], // (Optional) For 'I' in video_prompt_type
- // ... other parameters ...
+ "prompt": "string - Text description of desired video content",
+ "negative_prompt": "string - What to avoid in generation (optional)",
+ "model": "string - Model type (see model list below)",
+ "resolution": "string - WIDTHxHEIGHT format (e.g., '1280x720')",
+ "video_length": "integer - Number of frames (model-dependent limits)",
+ "num_inference_steps": "integer - Denoising steps (1-100, typically 20-50)",
+ "seed": "integer - Random seed (-1 for random, 0-999999999)",
+ "repeat_generation": "integer - Number of videos to generate (1-25)"
}
```
-### 1. `video_prompt_type` letters
-You can concatenate several letters (order does not matter). The most common combinations are shown in the table below:
+#### Advanced Parameters
+```json
+{
+ "guidance_scale": "float - CFG scale (1.0-20.0, typically 5.0-7.5)",
+ "flow_shift": "float - Flow shift parameter (0.0-25.0, typically 3.0-8.0)",
+ "embedded_guidance_scale": "float - For Hunyuan models (1.0-20.0)",
+ "audio_guidance_scale": "float - For fantasy model with audio (1.0-20.0)"
+}
+```
+
+#### Image Input Parameters
+```json
+{
+ "image_start": "string - Base64 encoded start image for i2v models",
+ "image_end": "string - Base64 encoded end image (fun_inp models only)",
+ "image_refs": "array - Base64 encoded reference images (phantom/hunyuan_custom)",
+ "image_prompt_type": "string - 'S' (start only) or 'SE' (start+end)"
+}
+```
+
+#### Video Input Parameters
+```json
+{
+ "video_source": "string - Path to source video (diffusion_forcing/ltxv/recam)",
+ "video_guide": "string - Path to control video (vace models)",
+ "video_mask": "string - Path to mask video (vace inpainting)",
+ "keep_frames_video_source": "string - Frames to keep from source",
+ "keep_frames_video_guide": "string - Frames to keep from guide"
+}
+```
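+
+As a sketch, a Vace control task might combine these fields like this (the guide-video path is a placeholder that must be readable by the worker, `"PV"` requests pose guidance extracted from the guide video plus the guide video itself, and the other values are illustrative):
+
+```bash
+python add_task.py --type video_generation --params '{
+  "prompt": "A dancer moving through a neon-lit alley",
+  "model": "vace_14B",
+  "resolution": "832x480",
+  "video_length": 81,
+  "num_inference_steps": 30,
+  "video_prompt_type": "PV",
+  "video_guide": "path/to/control_video.mp4",
+  "seed": 777
+}'
+```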
+
+#### Audio Parameters
+```json
+{
+ "audio_guide": "string - Path to audio file (fantasy/hunyuan_avatar models)"
+}
+```
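+
+A hedged sketch of an audio-driven task for the `fantasy` model (the audio path, start image and other values are illustrative placeholders):
+
+```bash
+python add_task.py --type video_generation --params '{
+  "prompt": "A character speaking enthusiastically to the camera",
+  "model": "fantasy",
+  "resolution": "1280x720",
+  "video_length": 81,
+  "num_inference_steps": 30,
+  "audio_guidance_scale": 5.0,
+  "audio_guide": "path/to/speech.wav",
+  "image_start": "base64_encoded_image_here",
+  "seed": 888
+}'
+```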
+
+#### LoRA Parameters
+```json
+{
+ "use_causvid_lora": "boolean - Enable LoRA usage",
+ "lora_name": "string - Specific LoRA name to use",
+ "activated_loras": "array - List of LoRA names",
+ "loras_multipliers": "string - LoRA strength values"
+}
+```
+
+#### Advanced Control Parameters
+```json
+{
+ "video_prompt_type": "string - Control type for vace ('I'=image, 'V'=video, 'M'=mask)",
+ "model_mode": "integer - Generation mode (0=sync, 5=async for diffusion_forcing)",
+ "sliding_window_size": "integer - Window size for long videos",
+ "sliding_window_overlap": "integer - Frame overlap between windows",
+ "tea_cache_setting": "float - Speed optimization (0=off, 1.5-2.5=faster)",
+ "RIFLEx_setting": "integer - Long video support (0=auto, 1=on, 2=off)"
+}
+```
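+
+A hedged sketch combining the long-video controls above with a Diffusion Forcing model (the specific values, including the window overlap and TeaCache setting, are illustrative rather than tuned defaults):
+
+```bash
+python add_task.py --type video_generation --params '{
+  "prompt": "A slow pan across a coastal city from day to night",
+  "model": "sky_df_14B",
+  "resolution": "960x544",
+  "video_length": 257,
+  "num_inference_steps": 30,
+  "sliding_window_size": 97,
+  "sliding_window_overlap": 17,
+  "tea_cache_setting": 2.0,
+  "seed": 999
+}'
+```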
+
+---
+
+## LoRA Usage Guide
+
+LoRAs (Low-Rank Adaptations) allow you to apply specialized styles, characters, or effects to your video generation. This guide covers everything you need to know about using LoRAs effectively.
+
+### Quick Start
+
+#### Single LoRA Usage
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A fantasy dragon in medieval style",
+ "model": "t2v",
+ "lora_name": "fantasy_medieval_style"
+}'
+```
+
+#### Multiple LoRAs with Custom Strengths
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A cyberpunk samurai with neon lights",
+ "model": "t2v",
+ "activated_loras": ["cyberpunk_style", "samurai_character", "neon_effects"],
+ "loras_multipliers": "1.2 0.8 1.0"
+}'
+```
+
+### LoRA Parameters Reference
+
+| Parameter | Type | Description | Example |
+|-----------|------|-------------|---------|
+| `lora_name` | string | Single LoRA filename (without extension) | `"anime_style"` |
+| `activated_loras` | array | Multiple LoRA filenames (without extensions) | `["style1", "char1", "effect1"]` |
+| `loras_multipliers` | string | Space-separated strength values (1.0 = default) | `"1.2 0.8 1.0"` |
+| `use_causvid_lora` | boolean | Enable/disable LoRA usage | `true` or `false` |
+
+### File Naming and Location
+
+#### ⚠️ Critical: Filename Requirements
+- **Always use filenames WITHOUT the `.safetensors` extension** in parameters
+- **Avoid dots in LoRA filenames** (known bug in matching system)
+- Use underscores instead: `anime_style_v2` instead of `anime.style.v2`
+
+#### Directory Structure
+Place your LoRA files in the appropriate model-specific directories:
+
+```
+Wan2GP/
+├── loras/ # Text-to-Video LoRAs (t2v, sky_df, etc.)
+├── loras_i2v/ # Image-to-Video LoRAs (i2v, fun_inp, etc.)
+├── loras_hunyuan/ # Hunyuan Video LoRAs
+├── loras_hunyuan_i2v/ # Hunyuan Image-to-Video LoRAs
+└── loras_ltxv/ # LTX Video LoRAs
+```
+
+#### Model-Specific LoRA Locations
+
+| Model Family | LoRA Directory | Supported Extensions |
+|--------------|----------------|---------------------|
+| `t2v`, `t2v_1.3B`, `sky_df_*`, `moviigen` | `Wan2GP/loras/` | `.safetensors`, `.sft` |
+| `i2v`, `i2v_720p`, `fun_inp*`, `flf2v_720p` | `Wan2GP/loras_i2v/` | `.safetensors`, `.sft` |
+| `hunyuan`, `hunyuan_custom` | `Wan2GP/loras_hunyuan/` | `.safetensors`, `.sft` |
+| `hunyuan_i2v`, `hunyuan_avatar` | `Wan2GP/loras_hunyuan_i2v/` | `.safetensors`, `.sft` |
+| `ltxv_13B`, `ltxv_13B_distilled` | `Wan2GP/loras_ltxv/` | `.safetensors`, `.sft` |
+
+### Usage Examples
+
+#### Basic Single LoRA
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A majestic castle in anime style",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "lora_name": "anime_architecture",
+ "seed": 12345
+}'
+```
+
+#### Multiple LoRAs with Different Strengths
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A cyberpunk warrior with glowing weapons in neon city",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "activated_loras": ["cyberpunk_style", "warrior_character", "neon_effects", "weapon_glow"],
+ "loras_multipliers": "1.3 1.0 0.8 1.1",
+ "seed": 54321
+}'
+```
+
+#### Image-to-Video with LoRA
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Transform this portrait into anime style with movement",
+ "model": "i2v_720p",
+ "resolution": "1280x720",
+ "image_start": "base64_encoded_image_here",
+ "lora_name": "anime_portrait_style",
+ "seed": 98765
+}'
+```
+
+#### Advanced: Dynamic LoRA Strengths
+You can vary LoRA strength over time using comma-separated values:
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Gradual style transformation",
+ "model": "t2v",
+ "activated_loras": ["style_a", "style_b"],
+ "loras_multipliers": "1.0,0.8,0.6,0.4 0.0,0.2,0.4,0.6",
+ "num_inference_steps": 30
+}'
+```
+
+### LoRA Strength Guidelines
+
+| Strength Range | Effect | Use Case |
+|----------------|--------|----------|
+| `0.1 - 0.4` | Subtle influence | Minor style adjustments |
+| `0.5 - 0.8` | Moderate effect | Balanced style application |
+| `0.9 - 1.2` | Strong influence | Dominant style (recommended) |
+| `1.3 - 2.0` | Very strong | Extreme stylization (may cause artifacts) |
+
+### Troubleshooting
+
+#### Common Issues
+
+**LoRA Not Loading:**
+- ✅ Check filename doesn't include `.safetensors` extension
+- ✅ Verify file is in correct directory for your model
+- ✅ Ensure filename doesn't contain dots (use underscores)
+- ✅ Check file isn't corrupted: `python check_loras.py`
+
+**Poor Quality Results:**
+- Try adjusting LoRA strength (0.8-1.2 range usually works best)
+- Ensure LoRA is compatible with your model type
+- Check if multiple LoRAs are conflicting
+
+**File Not Found Errors:**
+```bash
+# Validate all LoRA files
+python check_loras.py
+
+# Fix corrupted LoRA files
+python check_loras.py --fix
+```
+
+#### Debug Commands
+```bash
+# List all available LoRAs for t2v models
+ls -la Wan2GP/loras/
+
+# Check LoRA file integrity
+python check_loras.py --verbose
+
+# Test LoRA loading (dry run)
+python add_task.py --type video_generation --params '{"model": "t2v", "lora_name": "test_lora", "prompt": "test"}' --dry-run
+```
+
+### Best Practices
+
+1. **Naming Convention**: Use descriptive names with underscores: `fantasy_medieval_v2`
+2. **Organization**: Group related LoRAs in subdirectories within the main LoRA folder
+3. **Testing**: Start with single LoRAs before combining multiple ones
+4. **Strength Tuning**: Begin with 1.0 strength and adjust based on results
+5. **Compatibility**: Ensure LoRAs are trained for your specific model type
+6. **Backup**: Keep backups of working LoRA configurations
+
+### Advanced Features
+
+#### Automatic LoRA Download (Experimental)
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Studio Ghibli style animation",
+ "model": "t2v",
+ "additional_loras": {
+ "https://huggingface.co/user/ghibli-lora/resolve/main/ghibli_style.safetensors": 1.0
+ }
+}'
+```
+
+**Note**: This feature automatically downloads and caches LoRAs from URLs. Use with trusted sources only.
+
+---
+
+## Task Management Operations
+
+### Monitoring Tasks and Queue Status
+
+#### View All Tasks with Details (SQLite)
+```bash
+# View all tasks with key information
+sqlite3 tasks.db "SELECT id, task_type, status, created_at, updated_at FROM tasks ORDER BY created_at DESC;"
+
+# View detailed task information including prompts and parameters
+sqlite3 tasks.db "SELECT id, task_type, status, params, output_location, created_at FROM tasks ORDER BY created_at DESC;"
+
+# View only queued tasks
+sqlite3 tasks.db "SELECT id, task_type, status, created_at FROM tasks WHERE status = 'Queued' ORDER BY created_at ASC;"
+
+# View currently processing tasks
+sqlite3 tasks.db "SELECT id, task_type, status, generation_started_at FROM tasks WHERE status = 'In Progress';"
+
+# View completed tasks with outputs
+sqlite3 tasks.db "SELECT id, task_type, output_location, generation_processed_at FROM tasks WHERE status = 'Complete' ORDER BY generation_processed_at DESC;"
+```
+
+#### Extract Specific Task Details (SQLite)
+```bash
+# Get full task details including prompt and LoRA information
+sqlite3 tasks.db "SELECT id, task_type, params FROM tasks WHERE id = 'your-task-id';"
+
+# Get task status and timing information
+sqlite3 tasks.db "SELECT id, status, created_at, generation_started_at, generation_processed_at FROM tasks WHERE id = 'your-task-id';"
+
+# View task parameters in a more readable format (requires jq)
+sqlite3 -json tasks.db "SELECT id, task_type, params FROM tasks WHERE id = 'your-task-id';" | jq '.[0].params | fromjson'
+```
+
+#### Monitor Queue Statistics (SQLite)
+```bash
+# Get queue summary
+sqlite3 tasks.db "SELECT status, COUNT(*) as count FROM tasks GROUP BY status;"
+
+# Get task type distribution
+sqlite3 tasks.db "SELECT task_type, COUNT(*) as count FROM tasks GROUP BY task_type ORDER BY count DESC;"
+
+# Get recent task activity (last 24 hours)
+sqlite3 tasks.db "SELECT id, task_type, status, created_at FROM tasks WHERE datetime(created_at) > datetime('now', '-1 day') ORDER BY created_at DESC;"
+
+# Get average processing time for completed tasks
+sqlite3 tasks.db "SELECT task_type, AVG((julianday(generation_processed_at) - julianday(generation_started_at)) * 24 * 60) as avg_minutes FROM tasks WHERE status = 'Complete' AND generation_started_at IS NOT NULL AND generation_processed_at IS NOT NULL GROUP BY task_type;"
+```
+
+#### Advanced Task Queries (SQLite)
+```bash
+# Find tasks using specific LoRAs
+sqlite3 tasks.db "SELECT id, task_type, created_at FROM tasks WHERE params LIKE '%lora_name%' OR params LIKE '%activated_loras%';"
+
+# Find tasks with specific models
+sqlite3 tasks.db "SELECT id, task_type, created_at FROM tasks WHERE params LIKE '%\"model\":\"t2v\"%';"
+
+# Find failed tasks with error details
+sqlite3 tasks.db "SELECT id, task_type, output_location as error_message, created_at FROM tasks WHERE status = 'Failed' ORDER BY created_at DESC;"
+
+# Find tasks by prompt keywords
+sqlite3 tasks.db "SELECT id, task_type, created_at FROM tasks WHERE params LIKE '%spaceship%' OR params LIKE '%dragon%';"
+```
+
+#### Real-time Queue Monitoring (SQLite)
+```bash
+# Watch queue status (updates every 5 seconds)
+watch -n 5 'sqlite3 tasks.db "SELECT status, COUNT(*) FROM tasks GROUP BY status;"'
+
+# Monitor recent task activity
+watch -n 10 'sqlite3 tasks.db "SELECT id, task_type, status, datetime(created_at, \"localtime\") as local_time FROM tasks ORDER BY created_at DESC LIMIT 10;"'
+
+# Track processing times
+watch -n 30 'sqlite3 tasks.db "SELECT id, task_type, status, ROUND((julianday(\"now\") - julianday(generation_started_at)) * 24 * 60, 1) as minutes_running FROM tasks WHERE status = \"In Progress\";"'
+```
+
+#### Supabase Mode
+Use the Supabase dashboard or Edge Functions to check task status. You can also query the tasks table directly:
+
+```sql
+-- View all tasks
+SELECT id, task_type, status, created_at, updated_at FROM tasks ORDER BY created_at DESC;
+
+-- View task details with parameters
+SELECT id, task_type, status, params, output_location FROM tasks WHERE id = 'your-task-id';
+
+-- Monitor queue status
+SELECT status, COUNT(*) FROM tasks GROUP BY status;
+```
+
+### Canceling Tasks and Queue Management
+
+#### Cancel Individual Tasks
+
+**Cancel a specific queued task (SQLite):**
+```bash
+# Mark a specific task as cancelled
+sqlite3 tasks.db "UPDATE tasks SET status = 'Failed', output_location = 'Cancelled by user', updated_at = datetime('now') WHERE id = 'your-task-id' AND status = 'Queued';"
+
+# Verify cancellation
+sqlite3 tasks.db "SELECT id, status, output_location FROM tasks WHERE id = 'your-task-id';"
+```
+
+**Cancel a specific task (Python script):**
+```python
+# Create cancel_task.py
+from source import db_operations as db_ops
+
+# Cancel a specific task
+task_id = "your-task-id"
+db_ops.update_task_status(task_id, "Failed", "Cancelled by user")
+print(f"Task {task_id} has been cancelled")
+```
+
+#### Cancel All Queued Tasks
+
+**Remove all queued tasks (SQLite):**
+```bash
+# Show queued tasks before cancellation
+sqlite3 tasks.db "SELECT COUNT(*) as queued_count FROM tasks WHERE status = 'Queued';"
+
+# Cancel all queued tasks (marks as Failed instead of deleting)
+sqlite3 tasks.db "UPDATE tasks SET status = 'Failed', output_location = 'Bulk cancelled by user', updated_at = datetime('now') WHERE status = 'Queued';"
+
+# Alternative: Delete all queued tasks completely
+sqlite3 tasks.db "DELETE FROM tasks WHERE status = 'Queued';"
+
+# Verify results
+sqlite3 tasks.db "SELECT status, COUNT(*) FROM tasks GROUP BY status;"
+```
+
+**Bulk cancel script (Python):**
+```bash
+# Create bulk_cancel.py and run it
+python -c "
+import sqlite3
+from datetime import datetime
+
+conn = sqlite3.connect('tasks.db')
+cursor = conn.cursor()
+
+# Count queued tasks
+cursor.execute('SELECT COUNT(*) FROM tasks WHERE status = \"Queued\"')
+queued_count = cursor.fetchone()[0]
+print(f'Found {queued_count} queued tasks')
+
+if queued_count > 0:
+ # Mark all as cancelled
+ cursor.execute('UPDATE tasks SET status = \"Failed\", output_location = \"Bulk cancelled\", updated_at = ? WHERE status = \"Queued\"', (datetime.utcnow().isoformat() + 'Z',))
+ print(f'Cancelled {cursor.rowcount} tasks')
+
+conn.commit()
+conn.close()
+"
+```
+
+#### Cancel Specific Task Types
+```bash
+# Cancel all video_generation tasks in queue
+sqlite3 tasks.db "UPDATE tasks SET status = 'Failed', output_location = 'Cancelled - video_generation tasks' WHERE task_type = 'video_generation' AND status = 'Queued';"
+
+# Cancel all travel_orchestrator tasks
+sqlite3 tasks.db "UPDATE tasks SET status = 'Failed', output_location = 'Cancelled - travel tasks' WHERE task_type = 'travel_orchestrator' AND status = 'Queued';"
+
+# Delete failed tasks (cleanup)
+sqlite3 tasks.db "DELETE FROM tasks WHERE status = 'Failed';"
+```
+
+#### Stop Currently Running Tasks
+
+**⚠️ Important**: The Headless-Wan2GP system does not have built-in graceful task stopping mechanisms. Currently running tasks cannot be stopped cleanly through the database.
+
+**Manual intervention options:**
+
+1. **Restart the worker process:**
+```bash
+# Stop the headless worker (Ctrl+C or kill process)
+# The current task will be marked as failed on restart
+# Then restart:
+python headless.py --main-output-dir ./outputs
+```
+
+2. **Force mark running task as failed:**
+```bash
+# Find currently running tasks
+sqlite3 tasks.db "SELECT id, task_type, generation_started_at FROM tasks WHERE status = 'In Progress';"
+
+# Force mark as failed (will cause worker to skip it)
+sqlite3 tasks.db "UPDATE tasks SET status = 'Failed', output_location = 'Manually stopped', updated_at = datetime('now') WHERE status = 'In Progress';"
+```
+
+3. **System-level process termination:**
+```bash
+# Find Python processes
+ps aux | grep headless.py
+
+# Terminate specific process (replace PID)
+kill -TERM <PID>
-| Letter | Meaning in the UI | What the code does |
-| :----- | :----------------------- | :--------------------------------------------------------------------------------- |
-| `V` | Use the guide video as is | Feeds the RGB frames directly |
-| `P` | Pose guidance | Extracts DW-Pose skeletons from the guide video and feeds those instead of RGB |
-| `D` | Depth guidance | Extracts MiDaS depth maps and feeds them |
-| `G` | Greyscale guidance | Converts the guide video to grey before feeding |
-| `C` | Colour-transfer | Alias for `G`, lets the model recolourise the B&W guide |
-| `M` | External mask video | Expect `video_mask_path`; white = in/out-paint, black = keep |
-| `I` | Image references | Supply `image_refs_paths`; those images are encoded and prepended to the latent |
-| `O` | Original pixels on black | When present the pixels under the **black** part of the mask are copied through untouched (classic in-painting) |
+# Force kill if needed (may cause corruption)
+kill -KILL <PID>
+```
-Examples:
-* `"PV"` – feed **P**ose extracted from the guide **V**ideo.
-* `"DMV"` – **D**epth maps + mask **M** + original video **V**.
-* `"VMO"` – video + mask + keep pixels under black mask.
+**Note**: Stopping tasks mid-generation may leave temporary files and incomplete outputs. The system will clean up on the next restart.
-### 2. How the mask works
-A mask frame is converted to a single channel in range [0, 1].
+#### Advanced Queue Management
-| Value | Effect |
-| :-------- | :------------------------------------------------------------------ |
-| 0 (black) | Keep original RGB pixels (or copy the guide when `O` is active) |
-| 1 (white) | Pixels are blanked and the model is free to generate new content |
+**Pause queue processing:**
+```bash
+# Create a pause file to stop the worker from processing new tasks
+touch PAUSE_PROCESSING
-If you do **not** supply `video_mask_path`, the pipeline internally builds an all-white mask for every frame that is *not* listed in `keep_frames_video_guide` (see next section).
+# Remove to resume
+rm PAUSE_PROCESSING
+```
-### 3. `keep_frames_video_guide`
-Optional string that lists *whole frames* to keep. Syntax:
+**Priority task management:**
+```bash
+# Move specific task to front of queue (change created_at)
+sqlite3 tasks.db "UPDATE tasks SET created_at = datetime('now', '-1 day') WHERE id = 'priority-task-id';"
-* single index – positive (1 = first frame) or negative (-1 = last frame)
-* range – `a:b` is inclusive and uses the same indexing rules
-* separate items with a space.
+# Delay specific task (move to back)
+sqlite3 tasks.db "UPDATE tasks SET created_at = datetime('now', '+1 day') WHERE id = 'delay-task-id';"
+```
-Examples for an 81-frame clip:
+### Monitoring Task Progress
+#### Real-time Monitoring
+```bash
+# Watch task status updates
+watch -n 5 'sqlite3 tasks.db "SELECT id, task_type, status, updated_at FROM tasks ORDER BY updated_at DESC LIMIT 10;"'
```
-"" // empty ⇒ keep every frame (default)
-"1:20" // keep frames 0-19, generate the rest
-"1:20 -1" // keep frames 0-19 and the last frame
-"-10:-1" // keep the last 10 frames
+
+#### Check LoRA File Integrity
+```bash
+# Validate all LoRA files
+python check_loras.py
+
+# Fix corrupted LoRA files
+python check_loras.py --fix
```
-Frames not listed are zeroed and their mask set to white, so the network repaints them completely.
+---
+
+## Complete Model Examples
-If you only need per-pixel control you can omit `keep_frames_video_guide` and drive everything with the mask video alone.
+### Text-to-Video Models
-### 4. Quick recipes
+#### Standard Wan2.1 Models
-| Goal | Fields to set (`params` in database task) |
-| :--------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- |
-| Transfer only motion from a control video | `"video_prompt_type": "PV"`, `"video_guide_path": "path/to/control.mp4"` |
-| Depth-guided generation | `"video_prompt_type": "DV"`, `"video_guide_path": "path/to/control.mp4"` |
-| Classic in-painting with explicit mask | `"video_prompt_type": "MV"`, `"video_guide_path": "path/to/guide.mp4"`, `"video_mask_path": "path/to/mask.mp4"` |
-| Freeze first 20 & last frame, generate the rest | `"video_prompt_type": "VM"`, `"keep_frames_video_guide": "1:20 -1"`, `"video_guide_path": "path/to/guide.mp4"` (mask video can be all-white or omitted if guide is sufficient) |
+**t2v (14B) - High Quality Text-to-Video:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A majestic eagle soaring over snow-capped mountains",
+ "model": "t2v",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "guidance_scale": 5.0,
+ "flow_shift": 5.0,
+ "seed": 12345
+}'
+```
-This section should give you enough vocabulary to combine mask videos, depth/pose guidance and frame-freezing without modifying the code.
+**t2v_1.3B - Faster, Lower VRAM:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A robot walking through a neon-lit cyberpunk street",
+ "model": "t2v_1.3B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 25,
+ "guidance_scale": 5.0,
+ "flow_shift": 5.0,
+ "seed": 54321
+}'
+```
+
+#### Specialized Text-to-Video Models
+
+**sky_df_14B - Long Video Generation:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Time-lapse of clouds forming and dispersing over a landscape",
+ "model": "sky_df_14B",
+ "resolution": "960x544",
+ "video_length": 200,
+ "num_inference_steps": 30,
+ "guidance_scale": 6.0,
+ "flow_shift": 8.0,
+ "sliding_window_size": 97,
+ "seed": 11111
+}'
+```
+
+**sky_df_720p_14B - Long 720p Videos:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A journey through different seasons in a forest",
+ "model": "sky_df_720p_14B",
+ "resolution": "1280x720",
+ "video_length": 300,
+ "num_inference_steps": 30,
+ "guidance_scale": 6.0,
+ "flow_shift": 8.0,
+ "sliding_window_size": 121,
+ "seed": 22222
+}'
+```
+
+**ltxv_13B - Fast Long Videos:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A detailed cinematic sequence of a spaceship landing on an alien planet with strange flora and fauna",
+ "model": "ltxv_13B",
+ "resolution": "1280x720",
+ "video_length": 200,
+ "num_inference_steps": 30,
+ "sliding_window_size": 129,
+ "seed": 33333
+}'
+```
+
+**ltxv_13B_distilled - Very Fast:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A magical forest with glowing mushrooms and floating particles",
+ "model": "ltxv_13B_distilled",
+ "resolution": "832x480",
+ "video_length": 150,
+ "num_inference_steps": 20,
+ "seed": 44444
+}'
+```
+
+**hunyuan - Best Quality Text-to-Video:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A serene Japanese garden with koi fish swimming in a pond",
+ "model": "hunyuan",
+ "resolution": "1280x720",
+ "video_length": 97,
+ "num_inference_steps": 30,
+ "guidance_scale": 7.0,
+ "embedded_guidance_scale": 6.0,
+ "flow_shift": 13.0,
+ "seed": 55555
+}'
+```
+
+**moviigen - Cinematic 21:9:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Epic cinematic shot of a hero walking towards a sunset",
+ "model": "moviigen",
+ "resolution": "1920x832",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "guidance_scale": 5.0,
+ "flow_shift": 5.0,
+ "seed": 66666
+}'
+```
+
+### Image-to-Video Models
+
+**i2v - Standard Image-to-Video:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "The person in the image starts walking forward",
+ "model": "i2v",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_image_here",
+ "image_prompt_type": "S",
+ "seed": 77777
+}'
+```
+**i2v_720p - High Resolution Image-to-Video:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "The landscape comes alive with gentle movement",
+ "model": "i2v_720p",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_image_here",
+ "guidance_scale": 5.0,
+ "flow_shift": 7.0,
+ "seed": 88888
+}'
+```
+
+**hunyuan_i2v - Hunyuan Image-to-Video:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Subtle animation bringing the scene to life",
+ "model": "hunyuan_i2v",
+ "resolution": "1280x720",
+ "video_length": 97,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_image_here",
+ "guidance_scale": 7.0,
+ "embedded_guidance_scale": 6.0,
+ "seed": 99999
+}'
+```
+
+**fun_inp - Image-to-Video with End Frame:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Smooth transformation from start to end state",
+ "model": "fun_inp",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_start_image",
+ "image_end": "base64_encoded_end_image",
+ "image_prompt_type": "SE",
+ "seed": 111111
+}'
+```
+
+**fun_inp_1.3B - Faster End Frame Support:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Creative morphing between two different scenes",
+ "model": "fun_inp_1.3B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 25,
+ "image_start": "base64_encoded_start_image",
+ "image_end": "base64_encoded_end_image",
+ "seed": 222222
+}'
+```
+
+**flf2v_720p - Official Start/End Frame Model:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Professional transition between keyframes",
+ "model": "flf2v_720p",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_start_image",
+ "image_end": "base64_encoded_end_image",
+ "seed": 333333
+}'
+```
+
+**fantasy - Audio-Driven Animation:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Character speaking with natural lip sync",
+ "model": "fantasy",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_start": "base64_encoded_character_image",
+ "audio_guide": "/path/to/audio.wav",
+ "audio_guidance_scale": 5.0,
+ "seed": 444444
+}'
+```
+
+### ControlNet Models
+
+**vace_14B - Advanced Control:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "A dancer performing with precise movements",
+ "model": "vace_14B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "video_prompt_type": "PV",
+ "video_guide": "/path/to/pose_video.mp4",
+ "image_refs": ["base64_encoded_reference_image"],
+ "seed": 555555
+}'
+```
+
+**vace_1.3B - Faster Control:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Person walking with controlled movement",
+ "model": "vace_1.3B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 25,
+ "video_prompt_type": "I",
+ "image_refs": ["base64_encoded_reference_image"],
+ "seed": 666666
+}'
+```
+
+**phantom_14B - Object/Person Transfer:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "The person appears in a new magical environment",
+ "model": "phantom_14B",
+ "resolution": "1280x720",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "image_refs": ["base64_encoded_person_image"],
+ "guidance_scale": 7.5,
+ "flow_shift": 5.0,
+ "remove_background_images_ref": 1,
+ "seed": 777777
+}'
+```
+
+**phantom_1.3B - Faster Object Transfer:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Object floating in a dreamy landscape",
+ "model": "phantom_1.3B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 25,
+ "image_refs": ["base64_encoded_object_image"],
+ "guidance_scale": 7.5,
+ "flow_shift": 5.0,
+ "seed": 888888
+}'
+```
+
+**hunyuan_custom - Identity Preservation:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "The person is walking in a beautiful garden",
+ "model": "hunyuan_custom",
+ "resolution": "1280x720",
+ "video_length": 97,
+ "num_inference_steps": 30,
+ "image_refs": ["base64_encoded_person_image"],
+ "guidance_scale": 7.5,
+ "flow_shift": 13.0,
+ "seed": 999999
+}'
+```
+
+**hunyuan_avatar - Audio-Driven Avatar:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Person speaking naturally with audio sync",
+ "model": "hunyuan_avatar",
+ "resolution": "1280x720",
+ "video_length": 129,
+ "num_inference_steps": 30,
+ "image_refs": ["base64_encoded_person_image"],
+ "audio_guide": "/path/to/speech.wav",
+ "guidance_scale": 7.5,
+ "flow_shift": 5.0,
+ "tea_cache_start_step_perc": 25,
+ "seed": 101010
+}'
+```
+
+### Utility Models
+
+**recam_1.3B - Camera Movement Replay:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Recreate camera movement with new content",
+ "model": "recam_1.3B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "video_source": "/path/to/source_video.mp4",
+ "model_mode": 5,
+ "seed": 121212
+}'
+```
+
+### Advanced Examples with Multiple Parameters
+
+**Long Video with LoRA:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Epic fantasy battle scene with dragons",
+ "negative_prompt": "blurry, low quality, distorted",
+ "model": "sky_df_14B",
+ "resolution": "960x544",
+ "video_length": 400,
+ "num_inference_steps": 35,
+ "guidance_scale": 6.5,
+ "flow_shift": 8.0,
+ "sliding_window_size": 97,
+ "sliding_window_overlap": 17,
+ "use_causvid_lora": true,
+ "lora_name": "fantasy_style",
+ "tea_cache_setting": 1.5,
+ "RIFLEx_setting": 1,
+ "seed": 131313
+}'
+```
+
+**Complex Vace Control with Inpainting:**
+```bash
+python add_task.py --type video_generation --params '{
+ "prompt": "Character dancing in a new environment",
+ "model": "vace_14B",
+ "resolution": "832x480",
+ "video_length": 81,
+ "num_inference_steps": 30,
+ "video_prompt_type": "MV",
+ "video_guide": "/path/to/control_video.mp4",
+ "video_mask": "/path/to/mask_video.mp4",
+ "image_refs": ["base64_encoded_character"],
+ "keep_frames_video_guide": "1:40",
+ "guidance_scale": 6.0,
+ "flow_shift": 5.0,
+ "seed": 141414
+}'
+```
\ No newline at end of file
diff --git a/STRUCTURE.md b/STRUCTURE.md
index f4b0b3052..a4f00d577 100644
--- a/STRUCTURE.md
+++ b/STRUCTURE.md
@@ -1,32 +1,171 @@
+# Headless-Wan2GP Project Structure
+
+## Recent Updates (January 2025)
+
+### 🚀 **Major Architecture Improvements**
+- **✅ Complete Edge Function Migration**: Eliminated all RPC dependencies, now using pure Supabase Edge Functions
+- **✅ Dual Authentication System**: Perfect Service Key (worker) vs PAT (individual user) authentication
+- **✅ Worker Management**: Auto-creation system for worker IDs with proper constraint handling
+- **✅ Storage Integration**: Full Supabase storage upload/download functionality
+- **✅ Test Coverage**: Comprehensive test suite with 95.5% success rate (21/22 tests passing)
+
+### 🧹 **Repository Cleanup & Organization**
+- **Organized tests:** Moved comprehensive test suite to `tests/` directory
+- **Removed debug files:** Eliminated temporary videos, obsolete test scripts, and debug utilities
+- **Cleaned documentation:** Removed unnecessary .md files, kept essential STRUCTURE.md
+- **Removed SQL migrations:** Eliminated one-time migration files after successful deployment
+- **Streamlined codebase:** Production-ready components only
+
+## Core Architecture
+
+### **Database Operations (`source/db_operations.py`)**
+- **Pure Edge Function Integration**: All database operations via Supabase Edge Functions
+- **Dual Authentication**: Service role keys for workers, PATs for individual users
+- **Storage Management**: Upload/download to `image_uploads` bucket
+- **Worker ID Handling**: Automatic creation and constraint management
+
+### **Edge Functions (`supabase/functions/`)**
+1. **`create-task/`** - Task creation with RLS enforcement
+2. **`claim-next-task/`** - Atomic task claiming with dependency checking
+3. **`complete-task/`** - Task completion with file upload
+4. **`update-task-status/`** - Status updates (In Progress, Failed)
+5. **`get-predecessor-output/`** - Dependency chain resolution
+6. **`get-completed-segments/`** - Segment collection for stitching
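+
+As a rough illustration, these are ordinary Supabase Edge Functions served under `/functions/v1/`. The sketch below shows how a worker might call `claim-next-task` with a service-role key; the request body shown (`worker_id`) is an assumption for illustration, not a documented contract – see `source/db_operations.py` for the real call.
+
+```bash
+# Hedged sketch: claim the next queued task via the claim-next-task Edge Function.
+# SUPABASE_URL / SUPABASE_SERVICE_KEY are the same values headless.py uses;
+# the JSON body is illustrative only.
+curl -s -X POST "$SUPABASE_URL/functions/v1/claim-next-task" \
+  -H "Authorization: Bearer $SUPABASE_SERVICE_KEY" \
+  -H "Content-Type: application/json" \
+  -d '{"worker_id": "gpu-example-worker"}'
+```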
+
+### **Authentication Architecture**
+- **Service Role Path**: Uses `worker_id` for machine/process tracking
+- **User/PAT Path**: Clean task claiming without worker complexity
+- **RLS Enforcement**: Row-Level Security via Edge Functions
+- **Token Resolution**: PAT lookup via `user_api_tokens` table
+
+### **Worker Management (`fix_worker_issue.sql`)**
+- Auto-creation trigger for new worker IDs
+- Backfill existing workers from tasks
+- Specific worker ID support: `gpu-20250723_221138-afa8403b`
+- Constraint validation and foreign key management
+
# Project Structure
```
-├── steerable_motion.py
+├── add_task.py
+├── generate_test_tasks.py
├── headless.py
-├── sm_functions/
+├── test_supabase_headless.py # NEW: Test script for Supabase functionality
+├── SUPABASE_SETUP.md # NEW: Setup guide for Supabase mode
+├── source/
│ ├── __init__.py
│ ├── common_utils.py
-│ ├── travel_between_images.py
-│ └── different_pose.py
+│ ├── db_operations.py
+│ ├── specialized_handlers.py
+│ ├── video_utils.py
+│ ├── wgp_utils.py
+│ └── sm_functions/
+│ ├── __init__.py
+│ ├── travel_between_images.py
+│ ├── different_perspective.py
+│ └── single_image.py
+├── tasks/ # Task specifications
+│ └── HEADLESS_SUPABASE_TASK.md # NEW: Supabase implementation spec
+├── supabase/
+│ └── functions/
+│ ├── complete_task/ # Edge Function: uploads file & marks task complete
+│ ├── create_task/ # NEW Edge Function: queues task from client
+│ ├── claim_next_task/ # NEW Edge Function: claims next task (service-role → any, user → own only)
+│ ├── get_predecessor_output/ # NEW Edge Function: gets task dependency and its output in single call
+│ └── get-completed-segments/ # NEW Edge Function: fetches completed travel_segment outputs for a run_id, bypassing RLS
+├── logs/ # runtime logs (git-ignored)
+├── outputs/ # generated videos/images (git-ignored)
+├── samples/ # example inputs for docs & tests
+├── tests/ # pytest suite
+├── test_outputs/ # artefacts written by tests (git-ignored)
├── Wan2GP/ ← third-party video-generation engine (keep high-level only)
└── STRUCTURE.md (this file)
```
## Top-level scripts
-**steerable_motion.py** – CLI orchestrator. Parses user arguments for the two primary tasks, ensures the SQLite `tasks` table exists, sets global/debug flags and then delegates to task handlers in `sm_functions`.
+* **headless.py** – Headless service that polls the `tasks` database, claims work, and drives the Wan2GP generator (`wgp.py`). Includes extra handlers for OpenPose and RIFE interpolation tasks and can upload outputs to Supabase storage. **NEW**: Now supports both SQLite and Supabase backends via `--db-type` flag.
+* **add_task.py** – Lightweight CLI helper to queue a single new task into SQLite/Supabase. Accepts a JSON payload (or file) and inserts it into the `tasks` table.
+* **generate_test_tasks.py** – Developer utility that back-fills the database with synthetic images/prompts for integration testing and local benchmarking.
+* **tests/test_travel_workflow_db_edge_functions.py** – **NEW**: Comprehensive test script to verify Supabase Edge Functions, authentication, and database operations for the headless worker.
+
+## Supabase Upload System
+
+**NEW**: All task types now support automatic upload to Supabase Storage when configured:
+
+### How it works
+* **Local-first**: Files are always saved locally first for reliability
+* **Conditional upload**: If Supabase is configured, files are uploaded to the `image_uploads` bucket
+* **Consistent API**: All task handlers use the same two functions:
+ * `prepare_output_path_with_upload()` - Sets up local path and provisional DB location
+ * `upload_and_get_final_output_location()` - Handles upload and returns final URL/path for DB
+
+### Task type coverage
+* **single_image**: Generated images → Supabase bucket with public URLs
+* **travel_stitch**: Final stitched videos → Supabase bucket
+* **different_perspective**: Final posed images → Supabase bucket
+* **Standard WGP tasks**: All video outputs → Supabase bucket
+* **Specialized handlers**: OpenPose masks, RIFE interpolations, etc. → Supabase bucket
+
+### Database behavior
+* **SQLite mode**: `output_location` contains relative paths (e.g., `files/video.mp4`)
+* **Supabase mode**: `output_location` contains public URLs (e.g., `https://xyz.supabase.co/storage/v1/object/public/image_uploads/task_123/video.mp4`)
+* **Object naming**: Files stored as `{task_id}/{filename}` for collision-free organization
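+
+A quick way to see which convention a given deployment produced is to inspect the column directly. This is a hedged example for SQLite mode (it assumes `tasks.db` in the working directory); in Supabase mode the same column holds public storage URLs instead of relative paths.
+
+```bash
+# Illustrative check: where did the last few completed tasks store their outputs?
+sqlite3 tasks.db "SELECT id, output_location FROM tasks WHERE status = 'Complete' ORDER BY updated_at DESC LIMIT 5;"
+```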
+
+## source/ package
+
+This is the main application package.
+
+* **common_utils.py** – Reusable helpers (file downloads, ffmpeg helpers, MediaPipe keypoint interpolation, debug utilities, etc.). **UPDATED**: Now includes generalized Supabase upload functions (`prepare_output_path_with_upload`, `upload_and_get_final_output_location`) used by all task types.
+* **db_operations.py** – Handles all database interactions for both SQLite and Supabase. **UPDATED**: Now includes Supabase client initialization, Edge Function integration, and automatic backend selection based on `DB_TYPE`.
+* **specialized_handlers.py** – Contains handlers for specific, non-standard tasks like OpenPose generation and RIFE interpolation. **UPDATED**: Uses Supabase-compatible upload functions for all outputs.
+* **video_utils.py** – Provides utilities for video manipulation like cross-fading, frame extraction, and color matching.
+* **wgp_utils.py** – Thin wrapper around `Wan2GP.wgp` that standardises parameter names, handles LoRA quirks (e.g. CausVid, LightI2X), and exposes the single `generate_single_video` helper used by every task handler. **UPDATED**: Now includes comprehensive debugging throughout the generation pipeline with detailed frame count validation.
+
+### source/sm_functions/ sub-package
+
+Task-specific wrappers around the bulky upstream logic. These are imported by `headless.py` (and potentially by notebooks/unit tests) without dragging in the interactive Gradio UI shipped with Wan2GP. **UPDATED**: All task handlers now use generalized Supabase upload functions for consistent output handling.
+
+* **travel_between_images.py** – Implements the segment-by-segment interpolation pipeline between multiple anchor images. Builds guide videos, queues generation tasks, stitches outputs. **UPDATED**: Final stitched videos are uploaded to Supabase when configured. **NEW**: Extensive debugging system with `debug_video_analysis()` function that tracks frame counts, file sizes, and processing steps throughout the entire orchestrator → segments → stitching pipeline.
+* **different_perspective.py** – Generates a new perspective for a single image using an OpenPose or depth-driven guide video plus optional RIFE interpolation for smoothness. **UPDATED**: Final posed images are uploaded to Supabase when configured.
+* **single_image.py** – Minimal handler for one-off image-to-video generation without travel or pose manipulation. **UPDATED**: Generated images are uploaded to Supabase when configured.
+* **magic_edit.py** – **NEW**: Processes images through Replicate's black-forest-labs/flux-kontext-dev-lora model for scene transformations. Supports conditional InScene LoRA usage via `in_scene` parameter (true for scene consistency, false for creative freedom). Integrates with Supabase storage for output handling.
+* **__init__.py** – Re-exports public APIs (`run_travel_between_images_task`, `run_single_image_task`, `run_different_perspective_task`) and common utilities for convenient importing.
+
+## Additional runtime artefacts & folders
-**headless.py** – Headless server that continuously polls the `tasks` database, claims work, and drives the Wan2GP generator (`wgp.py`). Includes extra handlers for OpenPose and RIFE interpolation tasks and can upload outputs to Supabase storage.
+* **logs/** – Rolling log files captured by `headless.py` and unit tests. The directory is git-ignored.
+* **outputs/** – Default location for final video/image results when not explicitly overridden by a task payload.
+* **samples/** – A handful of small images shipped inside the repo that are referenced in the README and tests.
+* **tests/** – Pytest-based regression and smoke tests covering both low-level helpers and full task workflows.
+* **test_outputs/** – Artefacts produced by the test-suite; kept out of version control via `.gitignore`.
+* **tasks.db** – SQLite database created on-demand by the orchestrator to track queued, running, and completed tasks (SQLite mode only).
-## sm_functions/ package
+## Database Backends
-A light-weight, testable wrapper around the bulky `steerable_motion.py` logic. It holds:
+**NEW**: The system now supports two database backends:
-* **common_utils.py** – Reusable helpers (DB access, ffmpeg helpers, MediaPipe keypoint interpolation, debug utilities, etc.)
-* **travel_between_images.py** – Implements the segment-by-segment interpolation pipeline between multiple anchor images. Builds guide videos, queues generation tasks, stitches outputs.
-* **different_pose.py** – Generates a new pose for a single image using an OpenPose-driven guide video plus optional RIFE interpolation for smoothness.
-* **__init__.py** – Re-exports public APIs (`run_travel_between_images_task`, `run_different_pose_task`) and common utilities for convenient importing.
+### SQLite (Default)
+* Local file-based database (`tasks.db`)
+* No authentication required
+* Good for single-machine deployments
+* Files stored in `public/files/`
+
+### Supabase
+* Cloud-hosted PostgreSQL via Supabase
+* Supports Row-Level Security (RLS)
+* Enable with: `--db-type supabase --supabase-url <project-url> --supabase-access-token <token>`
+* Two token types:
+ * **User JWT**: Only processes tasks owned by that user
+ * **Service-role key**: Processes all tasks (bypasses RLS)
+* Files can be uploaded to Supabase Storage (see the Supabase Upload System section above)
+* Uses Edge Functions for database operations to handle RLS properly:
+ * `claim_next_task/` - Claims tasks with dependency checking
+ * `get_predecessor_output/` - Gets task dependencies and outputs
+ * `complete_task/` - Uploads files and marks tasks complete
+ * `create_task/` - Creates new tasks
+* Python code uses Edge Functions for Supabase, direct queries for SQLite
## Wan2GP/
@@ -56,9 +195,103 @@ The submodule is currently pinned to commit `6706709` ("optimization for i2v wit
## Runtime artefacts
-* **tasks.db** – SQLite database created on-demand by the orchestrator/server to track queued, running, and completed tasks.
-* Intermediate segment folders under `steerable_motion_output/` (or user-provided directory) – automatically cleaned unless `--debug`/`--skip_cleanup` is set.
+* **tasks.db** – SQLite database created on-demand by the orchestrator/server to track queued, running, and completed tasks (SQLite mode only).
+* **public/files/** – For SQLite mode, all final video outputs are saved directly here with descriptive filenames (e.g., `{run_id}_seg00_output.mp4`, `{run_id}_final.mp4`). No nested subdirectories are created.
+* **outputs/** – For non-SQLite modes or when explicitly configured, videos are saved here with task-specific subdirectories.
+
+## End-to-End task lifecycle (1-minute read)
+
+1. **Task injection** – A CLI, API, or test script calls `add_task.py`, which inserts a new row into the `tasks` table (SQLite or Supabase). Payload JSON is stored in `params`, `status` is set to `Queued`.
+2. **Worker pickup** – `headless.py` runs in a loop, atomically updates a `Queued` row to `In Progress`, and inspects `task_type` to choose the correct handler.
+3. **Handler execution**
+ * Standard tasks live in `source/sm_functions/…` (see table below).
+ * Special one-offs (OpenPose, RIFE, etc.) live in `specialized_handlers.py`.
+ * Handlers may queue **sub-tasks** (e.g. travel → N segments + 1 stitch) by inserting new rows with `dependant_on` set, forming a DAG.
+4. **Video generation** – Every handler eventually calls `wgp_utils.generate_single_video` which wraps **Wan2GP/wgp.py** and returns a path to the rendered MP4.
+5. **Post-processing** – Optional saturation / brightness / colour-match (`video_utils.py`) or upscaling tasks.
+6. **DB update** – Handler stores `output_location` (relative in SQLite, absolute or URL in Supabase) and marks the row `Complete` (or `Failed`). Dependants are now eligible to start.
+7. **Cleanup** – Intermediate folders are deleted unless `debug_mode_enabled` or `skip_cleanup_enabled` flags are set in the payload.
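+
+As a small illustration of steps 1–3, the sketch below enqueues two tasks where the second depends on the first via `--dependant-on`; the task type and payload fields are illustrative, not a canonical orchestrator payload.
+
+```bash
+# Hedged sketch: a two-node DAG. The child stays Queued until the parent is Complete.
+python add_task.py --type video_generation \
+  --params '{"task_id": "demo-parent", "prompt": "A quiet lake at dawn", "model": "t2v_1.3B"}'
+
+python add_task.py --type video_generation \
+  --params '{"task_id": "demo-child", "prompt": "The same lake at dusk", "model": "t2v_1.3B"}' \
+  --dependant-on demo-parent
+```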
+
+## Quick task-to-file reference
+
+| Task type / sub-task | Entrypoint function | File |
+|----------------------|---------------------|------|
+| Travel orchestrator | `_handle_travel_orchestrator_task` | `sm_functions/travel_between_images.py` |
+| Travel segment | `_handle_travel_segment_task` | `sm_functions/travel_between_images.py` |
+| Travel stitch | `_handle_travel_stitch_task` | `sm_functions/travel_between_images.py` |
+| Single image video | `run_single_image_task` | `sm_functions/single_image.py` |
+| Different perspective | `run_different_perspective_task` | `sm_functions/different_perspective.py` |
+| Magic edit | `_handle_magic_edit_task` | `sm_functions/magic_edit.py` |
+| OpenPose mask video | `handle_openpose_task` | `specialized_handlers.py` |
+| RIFE interpolation | `handle_rife_task` | `specialized_handlers.py` |
+
+All of the above eventually call `wgp_utils.generate_single_video`, which is the single **shared** bridge into Wan2GP.
+
+## Database cheat-sheet
+
+Column | Purpose
+-------|---------
+`id` | UUID primary key (task_id)
+`task_type` | e.g. `travel_segment`, `wgp`, `travel_stitch`
+`dependant_on` | Optional FK forming execution DAG
+`params` | JSON payload saved by the enqueuer
+`status` | `Queued` → `In Progress` → `Complete`/`Failed`
+`output_location` | Where the final artefact lives (string)
+`updated_at` | Heartbeat & ordering
+`project_id` | Links to project (required for Supabase RLS)
+
+SQLite keeps the DB at `tasks.db`; Supabase uses the same columns with RLS policies.
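+
+For SQLite deployments, a hedged query along these lines shows which queued tasks are currently eligible to run (their dependency, if any, is already `Complete`); the worker's real claiming logic lives in `db_operations.py` and the Edge Functions, so treat this as illustration only.
+
+```bash
+sqlite3 tasks.db "
+  SELECT id, task_type, dependant_on
+  FROM tasks
+  WHERE status = 'Queued'
+    AND (dependant_on IS NULL
+         OR dependant_on IN (SELECT id FROM tasks WHERE status = 'Complete'));"
+```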
+
+## Debugging System
+
+**NEW**: Comprehensive debugging system for video generation pipeline with detailed frame count tracking and validation:
+
+### Debug Functions
+* **`debug_video_analysis()`** – Analyzes any video file and reports frame count, FPS, duration, file size with clear labeling
+* **Frame count validation** – Compares expected vs actual frame counts at every processing step with ⚠️ warnings for mismatches
+* **Processing step tracking** – Logs success/failure of each chaining step (saturation, brightness, color matching, banner overlay)
+
+### Debug Output Categories
+* **`[FRAME_DEBUG]`** – Orchestrator frame quantization and overlap calculations
+* **`[SEGMENT_DEBUG]`** – Individual segment processing parameters and frame analysis
+* **`[WGP_DEBUG]`** – WGP generation parameters, results, and frame count validation
+* **`[CHAIN_DEBUG]`** – Post-processing chain (saturation, brightness, color matching) with step-by-step analysis
+* **`[STITCH_DEBUG]`** – Path resolution, video collection, and cross-fade analysis
+* **`[CRITICAL_DEBUG]`** – Critical stitching calculations and frame count summaries
+* **`[STITCH_FINAL_ANALYSIS]`** – Complete final video analysis with expected vs actual comparisons
+
+### Key Features
+* **Video analysis at every step** – Frame count, FPS, duration, file size tracked throughout pipeline
+* **Path resolution debugging** – Detailed logging of SQLite-relative, absolute, and URL path handling
+* **Cross-fade calculation verification** – Step-by-step analysis of overlap processing and frame arithmetic
+* **Mismatch highlighting** – Clear warnings when frame counts don't match expectations
+* **Processing chain validation** – Success/failure tracking for each post-processing step
+
+This debugging system provides comprehensive visibility into the video generation pipeline to identify exactly where frame counts change and why final outputs might have unexpected lengths.
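+
+Because the categories above are plain bracketed tags, the rolling logs under `logs/` can be filtered with standard tools. A hedged example (assuming the tags appear verbatim in the log files):
+
+```bash
+# Pull the last 50 stitching-related debug lines from all logs
+grep -h "\[STITCH_DEBUG\]\|\[STITCH_FINAL_ANALYSIS\]" logs/*.log | tail -n 50
+```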
+
+## LoRA Support
+
+### Special LoRA Flags
+
+* **`use_causvid_lora`** – Enables CausVid LoRA with 9 steps, guidance 1.0, flow-shift 1.0. Auto-downloads from HuggingFace if missing.
+* **`use_lighti2x_lora`** – Enables LightI2X LoRA with 5 steps, guidance 1.0, flow-shift 5.0, Tea Cache disabled. Auto-downloads from HuggingFace if missing.
+
+Both flags automatically configure optimal generation parameters and handle LoRA downloads/activation.
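+
+A hedged example of queuing a task with one of these flags via `add_task.py`; the flag name is the documented one, while the surrounding parameters are illustrative defaults, not recommendations:
+
+```bash
+python add_task.py --type video_generation --params '{
+  "prompt": "A calm ocean at sunrise",
+  "model": "i2v",
+  "resolution": "832x480",
+  "video_length": 81,
+  "image_start": "base64_encoded_image_here",
+  "use_lighti2x_lora": true,
+  "seed": 12345
+}'
+```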
+
+## Environment & config knobs (non-exhaustive)
+
+Variable / flag | Effect
+----------------|-------
+`SUPABASE_URL / SUPABASE_SERVICE_KEY` | Used for Supabase connection (if not provided via CLI).
+`POSTGRES_TABLE_NAME` | Table name for Supabase (default: `tasks`).
+`SUPABASE_VIDEO_BUCKET` | Storage bucket name for video and image uploads.
+`WAN2GP_CACHE` | Where Wan2GP caches model weights.
+`--debug` | Prevents cleanup of temp folders, extra logs.
+`--skip_cleanup` | Keeps all intermediate artefacts even outside debug.
+`--db-type` | Choose between `sqlite` (default) or `supabase`.
+`--supabase-url` | Supabase project URL (required for Supabase mode).
+`--supabase-access-token` | JWT token or service-role key for authentication.
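+
+Putting the flags together, a hedged pair of launch commands (substitute your own placeholder values):
+
+```bash
+# SQLite (default) – queue lives in ./tasks.db
+python headless.py
+
+# Supabase – replace the URL and token with your project's values
+python headless.py --db-type supabase \
+  --supabase-url "https://your-project.supabase.co" \
+  --supabase-access-token "$SUPABASE_SERVICE_KEY"
+```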
---
-This document is auto-generated to give newcomers a concise map of the codebase; update it when new modules or packages are added.
\ No newline at end of file
+Keep this file **brief** – for in-depth developer docs see the `docs/` folder and inline module docstrings.
\ No newline at end of file
diff --git a/Wan2GP b/Wan2GP
index 7670af961..026c2b0cb 160000
--- a/Wan2GP
+++ b/Wan2GP
@@ -1 +1 @@
-Subproject commit 7670af9610c860247d7086c3501d6264b9d86a74
+Subproject commit 026c2b0cbb30cc5fadce937ce56364d20efb8d94
diff --git a/add_task.py b/add_task.py
new file mode 100644
index 000000000..e1c7a27dd
--- /dev/null
+++ b/add_task.py
@@ -0,0 +1,73 @@
+import argparse
+import json
+import sys
+from pathlib import Path
+
+# Ensure project source directory is on the import path
+proj_root = Path(__file__).resolve().parent
+if str(proj_root) not in sys.path:
+ sys.path.insert(0, str(proj_root))
+
+from source import db_operations as db_ops
+from source.common_utils import generate_unique_task_id as sm_generate_unique_task_id
+
+
+def _load_params(param_arg: str) -> dict:
+ """Load params from a JSON string or from a @file reference."""
+ if param_arg.startswith("@"):
+ file_path = Path(param_arg[1:]).expanduser()
+ if not file_path.exists():
+ raise FileNotFoundError(f"Params file not found: {file_path}")
+ with open(file_path, "r", encoding="utf-8") as fp:
+ return json.load(fp)
+ # Otherwise, treat as literal JSON string
+ return json.loads(param_arg)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ "add_task – enqueue a task for the Wan2GP headless server"
+ )
+ parser.add_argument(
+ "--type",
+ required=True,
+ help="Task type string (e.g. travel_orchestrator, travel_segment, generate_openpose, …)",
+ )
+ parser.add_argument(
+ "--params",
+ required=True,
+        help="JSON string with task payload OR @<path-to-json-file>",
+ )
+ parser.add_argument(
+ "--dependant-on",
+ dest="dependant_on",
+ default=None,
+ help="Optional task_id that this new task depends on.",
+ )
+ args = parser.parse_args()
+
+ try:
+ payload_dict = _load_params(args.params)
+ except Exception as e:
+ print(f"[ERROR] Could not parse --params: {e}")
+ sys.exit(1)
+
+ # Auto-generate task_id if needed
+ if "task_id" not in payload_dict or not payload_dict["task_id"]:
+ payload_dict["task_id"] = sm_generate_unique_task_id(f"{args.type[:8]}_")
+ print(f"[INFO] task_id not supplied – generated: {payload_dict['task_id']}")
+
+ try:
+ db_ops.add_task_to_db(
+ task_payload=payload_dict,
+ task_type_str=args.type,
+ dependant_on=args.dependant_on,
+ )
+ print(f"[SUCCESS] Task {payload_dict['task_id']} ({args.type}) enqueued.")
+ except Exception as e_db:
+ print(f"[ERROR] Failed to enqueue task: {e_db}")
+ sys.exit(2)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/check_loras.py b/check_loras.py
new file mode 100644
index 000000000..2e4daa04c
--- /dev/null
+++ b/check_loras.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""
+LoRA File Integrity Checker
+
+This script checks LoRA files in common directories for corruption or size issues.
+It can also clean up corrupted files if requested.
+
+Usage:
+ python check_loras.py # Check all LoRA directories
+ python check_loras.py --fix # Check and remove corrupted files
+ python check_loras.py --dir /path/to/loras # Check specific directory
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+# Add source directory to path to import common_utils
+sys.path.append(str(Path(__file__).parent / "source"))
+from common_utils import check_loras_in_directory
+
+def main():
+ parser = argparse.ArgumentParser(description="Check LoRA file integrity")
+ parser.add_argument("--dir", type=str, help="Specific directory to check")
+ parser.add_argument("--fix", action="store_true", help="Remove corrupted files")
+ parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
+
+ args = parser.parse_args()
+
+ # Default LoRA directories to check
+ default_dirs = [
+ "Wan2GP/loras",
+ "Wan2GP/loras_hunyuan",
+ "Wan2GP/loras_hunyuan_i2v",
+ "Wan2GP/loras_i2v",
+ "Wan2GP/loras_ltxv"
+ ]
+
+ if args.dir:
+ dirs_to_check = [args.dir]
+ else:
+ dirs_to_check = default_dirs
+
+ total_checked = 0
+ total_valid = 0
+ total_invalid = 0
+
+ print("🔍 LoRA File Integrity Check")
+ print("=" * 50)
+
+ for lora_dir in dirs_to_check:
+ dir_path = Path(lora_dir)
+ print(f"\n📁 Checking directory: {dir_path}")
+
+ if not dir_path.exists():
+ print(f" ⚠️ Directory not found: {dir_path}")
+ continue
+
+ results = check_loras_in_directory(dir_path, fix_issues=args.fix)
+
+ if "error" in results:
+ print(f" ❌ Error: {results['error']}")
+ continue
+
+ total_checked += results["total_files"]
+ total_valid += results["valid_files"]
+ total_invalid += results["invalid_files"]
+
+ print(f" 📊 Files found: {results['total_files']}")
+ print(f" ✅ Valid: {results['valid_files']}")
+ print(f" ❌ Invalid: {results['invalid_files']}")
+
+ if results["invalid_files"] > 0:
+ print(f" 🚨 Issues found:")
+ for issue in results["issues"]:
+ print(f" {issue}")
+
+ if args.verbose:
+ print(f" 📋 Detailed results:")
+ for summary_line in results["summary"]:
+ print(f" {summary_line}")
+
+ print("\n" + "=" * 50)
+ print(f"📈 Summary:")
+ print(f" Total files checked: {total_checked}")
+ print(f" Valid files: {total_valid}")
+ print(f" Invalid files: {total_invalid}")
+
+ if total_invalid > 0:
+ print(f"\n💡 Found {total_invalid} corrupted files.")
+ if not args.fix:
+ print(" Run with --fix to automatically remove corrupted files.")
+ else:
+ print(" Corrupted files have been removed.")
+ sys.exit(1)
+ else:
+ print("\n🎉 All LoRA files are valid!")
+ sys.exit(0)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/env.txt b/env.txt
deleted file mode 100644
index ec2b77b4c..000000000
--- a/env.txt
+++ /dev/null
@@ -1,7 +0,0 @@
- DB_TYPE=supabase
- POSTGRES_TABLE_NAME="tasks" # Desired table name for tasks, used in RPC calls
-
- # For Supabase interactions (DB via RPC and Storage)
- SUPABASE_URL="https://ddbobialzdjkzainyqgb.supabase.co"
- SUPABASE_SERVICE_KEY=""
- SUPABASE_VIDEO_BUCKET="videos" # Your Supabase storage bucket name
\ No newline at end of file
diff --git a/explanations/WanEncodingExplanation.md b/explanations/WanEncodingExplanation.md
deleted file mode 100644
index c309a4c55..000000000
--- a/explanations/WanEncodingExplanation.md
+++ /dev/null
@@ -1,263 +0,0 @@
-# VACE Encoding in Wan2GP Explained
-
-This document explains how the VACE (Video AutoEncoder ControlNet-like) encoding is applied within the Wan2GP framework. VACE enables ControlNet-like functionalities for video generation, such as video-to-video and reference-image-to-video tasks.
-
-## Overview
-
-The VACE encoding process involves several key components and steps:
-
-1. **Initialization**: When a VACE model is selected, specific VACE components are initialized and integrated into the main `WanModel`.
-2. **Input Preprocessing**: Video frames, reference images, and masks are preprocessed.
-3. **VAE Encoding**: The preprocessed inputs are encoded into a latent representation using the `WanVAE`.
- * Frames are encoded by `WanVAE`. If masks are present, "inactive" and "reactive" parts of the frames are encoded separately and then combined.
- * Reference images are also encoded by `WanVAE`.
- * Masks are processed (reshaped and interpolated) to match the VAE's latent dimensions.
-4. **VACE Context Creation**: The encoded frames, reference images, and masks are combined to form the `vace_context`.
-5. **Integration into `WanModel`**: The `vace_context` is injected into the `WanModel`'s attention blocks at specified layers during the diffusion process.
-
-## Detailed Steps and File References
-
-### 1. Initialization and Model Adaptation
-
-* **File**: `Wan2GP/wan/text2video.py`
-* **Class**: `WanT2V`
- * In the `__init__` method (around lines 42-98):
- * If the `model_filename` contains "Vace", a `VaceVideoProcessor` is initialized (line 90).
- * The crucial step is `self.adapt_vace_model()` (line 97).
- * **`adapt_vace_model()` method (lines 569-576)**:
- * This method is responsible for integrating VACE functionality into the main `WanModel`.
- * It iterates through `model.vace_layers_mapping` (which is defined in `Wan2GP/wan/modules/model.py` during `WanModel` initialization if `vace_layers` are specified).
- * For each VACE layer, it takes a pre-initialized VACE-specific attention block from `model.vace_blocks` (these are `VaceWanAttentionBlock` instances) and assigns it as an attribute (named `vace`) to the corresponding standard `WanAttentionBlock` in `model.blocks`.
- * Example: `setattr(target, "vace", module)`, where `target` is a `WanAttentionBlock` and `module` is a `VaceWanAttentionBlock`.
- * After this, the original `model.vace_blocks` attribute is deleted from the model.
-
-### 2. VACE Input Encoding
-
-* **File**: `Wan2GP/wan/text2video.py`
-* **Class**: `WanT2V`
- * **`vace_encode_frames(frames, ref_images, masks, ...)` method (lines 100-130)**:
- * This method encodes the input video frames and reference images.
- * It uses `self.vae.encode(...)` for the actual encoding. `self.vae` is an instance of `WanVAE` (from `Wan2GP/wan/modules/vae.py`).
- * **If masks are provided**:
- * Frames are split into `inactive` parts (`i * (1 - m)`) and `reactive` parts (`i * m`).
- * Both parts are encoded separately by `self.vae.encode()`.
- * The resulting latents are concatenated: `latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]`. This doubles the channel dimension of the latents for masked regions.
- * **If reference images (`ref_images`) are provided**:
- * Reference images are encoded using `self.vae.encode()`.
- * If masks were also provided, the reference latents are padded with zeros on one half of the channel dimension (`ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]`).
- * The reference latents are concatenated with the frame latents along the temporal dimension (`dim=1`).
- * **`vace_encode_masks(masks, ref_images=None)` method (lines 132-160)**:
- * This method processes the input masks.
- * Masks are reshaped and then interpolated using `F.interpolate(..., mode='nearest-exact')` to match the spatial and temporal dimensions of the VAE's latent space (considering `self.vae_stride`).
- * If reference images are provided, the mask is padded with zeros at the beginning of the temporal dimension to account for the concatenated reference latents.
- * **`vace_latent(z, m)` method (line 162)**:
- * A simple utility that concatenates the latents from `vace_encode_frames` (`z`) and the processed masks from `vace_encode_masks` (`m`) along the channel dimension (`dim=0`).
- * `z` itself (from `vace_encode_frames`) might already have a doubled channel dimension if input masks were used. So, if an input mask was provided, `z` would be `[2*C, T, H, W]` (inactive_latent, reactive_latent) and `m` would be `[C_mask, T, H, W]`. This function seems to concatenate them further, but `z` in the `generate` function (where `vace_latent` is called) is `z0` which is the output of `vace_encode_frames`. If masks were used in `vace_encode_frames`, `z0` already contains both inactive and reactive parts. The `m0` (encoded masks) are concatenated to this `z0`.
- * More accurately, looking at the `generate` function (line 376-377), `z0 = self.vace_encode_frames(...)` and `m0 = self.vace_encode_masks(...)`, then `z = self.vace_latent(z0, m0)`. If `vace_encode_frames` already concatenated inactive/reactive parts, `z0` has shape `[B, 2*channels, T, H, W]`. Then `m0` (the VAE-encoded mask itself) is concatenated, leading to `[B, 2*channels + mask_channels, T, H, W]`. This resulting `z` is the `vace_context`.
-
-### 3. `WanVAE` Encoding Details
-
-* **File**: `Wan2GP/wan/modules/vae.py`
-* **Class**: `WanVAE`
- * The `encode(videos, tile_size, ...)` method (lines 813-823) is a wrapper.
- * It calls `self.model.encode(...)` or `self.model.spatial_tiled_encode(...)`. `self.model` is an instance of `WanVAE_` (note the underscore).
-* **Class**: `WanVAE_`
- * **`encode(x, scale=None, ...)` method (lines 534-569)**:
- * The input video tensor `x` is processed in temporal chunks (first frame, then 4-frame chunks, then last frame).
- * Each chunk goes through `self.encoder` (an `Encoder3d` instance).
- * The outputs are concatenated temporally.
- * A final convolution `self.conv1(out).chunk(2, dim=1)` produces `mu` and `log_var`.
- * Only `mu` is returned (potentially scaled) and used as the latent representation.
-
-### 4. Injection into `WanModel`
-
-* **File**: `Wan2GP/wan/text2video.py`
-* **Class**: `WanT2V`
- * In the `generate()` method (around lines 275-473):
- * If VACE is active (`vace = "Vace" in model_filename`):
- * `z0 = self.vace_encode_frames(...)` (line 376)
- * `m0 = self.vace_encode_masks(...)` (line 377)
- * `z = self.vace_latent(z0, m0)` (line 378). This `z` is the `vace_context`.
- * This `vace_context` is added to the keyword arguments passed to the diffusion model's sampling loop: `kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})` (line 442).
-
-* **File**: `Wan2GP/wan/modules/model.py`
-* **Class**: `WanModel`
- * **`__init__(..., vace_layers=None, vace_in_dim=None, ...)` constructor (lines 656-808)**:
- * If `vace_layers` is provided (which happens if a VACE model is being loaded):
- * `self.vace_layers_mapping` is created.
- * `self.blocks` (list of `WanAttentionBlock`) are recreated, and each `WanAttentionBlock` that corresponds to a VACE layer gets a `block_id` (line 773).
- * `self.vace_blocks` is initialized as a `nn.ModuleList` of `VaceWanAttentionBlock` instances (lines 779-783). These are the blocks that were later assigned using `setattr` in `WanT2V.adapt_vace_model`.
- * `self.vace_patch_embedding = nn.Conv3d(...)` is created (lines 786-788). This layer processes the raw `vace_context` before it's fed into the attention blocks.
- * **`forward(..., vace_context=None, vace_context_scale=1.0, ...)` method (lines 902-1081)**:
- * If `vace_context` is provided (lines 1017-1024):
- * The `vace_context` is first processed by `self.vace_patch_embedding`.
- * This processed context `c` is then packaged into `hints_list = [ [c] for _ in range(len(x_list)) ]`.
- * Inside the loop iterating through `self.blocks` (line 1049, though the loop itself is not shown in this snippet but implied by how `block(x, e, hints=hints_list[i], ...)` is called):
- * Each `block` (which is a `WanAttentionBlock`) receives the `hints` (the processed `vace_context`).
-
-* **File**: `Wan2GP/wan/modules/model.py`
-* **Class**: `WanAttentionBlock`
- * **`forward(..., hints=None, context_scale=1.0, ...)` method (lines 397-499)**:
- * If `self.block_id is not None` (meaning it's a VACE-designated layer) and `hints` are provided (line 417):
- * `hint = self.vace(hints, x, **kwargs)` (line 423 or 425). Here, `self.vace` is the `VaceWanAttentionBlock` instance that was attached during `adapt_vace_model`.
- * The returned `hint` from the `VaceWanAttentionBlock` is added to the main feature map `x`: `x.add_(hint, alpha=context_scale)` (line 497).
-
-* **File**: `Wan2GP/wan/modules/model.py`
-* **Class**: `VaceWanAttentionBlock` (inherits from `WanAttentionBlock`)
- * **`__init__(...)` (lines 504-522)**:
- * Initializes `before_proj` and `after_proj` linear layers with zero weights and biases.
- * **`forward(self, hints, x, **kwargs)` method (lines 526-535)**:
- * This is where the VACE conditioning is directly applied.
- * `c = hints[0]` gets the VACE context features (output of `vace_patch_embedding` from `WanModel.forward`).
- * If it's the first VACE block (`self.block_id == 0`), it applies `self.before_proj` to `c` and then adds the current timestep's feature map `x`: `c += x`.
- * It then calls `super().forward(c, **kwargs)`. This means the (potentially modified) VACE context `c` is processed through a standard `WanAttentionBlock`'s attention and FFN layers.
- * The output of this is `c`.
- * `c_skip = self.after_proj(c)`: The result is passed through another projection.
- * `hints[0] = c`: The result *before* the `after_proj` is stored back into `hints[0]`. This is interesting, as it means subsequent VACE blocks in the *same timestep* would receive the output of the previous VACE block's main processing path, not the skip connection. (Correction: `hints` is a list passed down. `hints[0]` is modified in place. This `c` will be the input `hints[0]` for the *next* `VaceWanAttentionBlock` if multiple are chained directly without an intermediate normal `WanAttentionBlock`. However, the structure `adapt_vace_model` sets up `target.vace = module`, so each `WanAttentionBlock.forward` calls its own `self.vace` instance with the *original* hint from `WanModel.forward`.)
- * The actual returned value is `c_skip`, which is then added to the main path in `WanAttentionBlock.forward`.
-
-## Summary of VACE Context Flow
-
-1. **`WanT2V.generate`**:
- * `input_frames`, `input_masks`, `input_ref_images` -> `vace_encode_frames` & `vace_encode_masks` (using `WanVAE`) -> `z0`, `m0`.
- * `vace_latent(z0, m0)` -> `vace_context`.
-2. **`WanModel.forward`**:
- * `vace_context` -> `self.vace_patch_embedding` -> `c` (processed VACE context).
- * `c` is put into `hints_list`.
-3. **`WanModel`'s loop over `blocks`**:
- * For each `block` (a `WanAttentionBlock`): `block(x, ..., hints=hints_list[...])`.
-4. **`WanAttentionBlock.forward`**:
- * If it's a VACE layer (`self.block_id is not None`):
- * `hint_output = self.vace(hints, x, ...)` (where `self.vace` is a `VaceWanAttentionBlock`).
- * `x = x + hint_output * context_scale`.
-5. **`VaceWanAttentionBlock.forward`**:
- * Receives `hints` (containing processed `vace_context` `c`) and `x` (current features).
- * `c_modified = c` (or `before_proj(c) + x` if first VACE block).
- * `c_processed = super().forward(c_modified, ...)` (pass through standard attention block).
- * `c_skip = self.after_proj(c_processed)`.
- * Returns `c_skip`.
-
-This detailed flow shows how VACE conditions the generation process by injecting its encoded representation of frames, masks, and reference images into specific layers of the main `WanModel` transformer via specialized attention blocks.
-
-## Interaction with `wgp.py` (Gradio UI)
-
-The main application script, `wgp.py`, handles user inputs from the Gradio web interface and orchestrates the video generation process, including the VACE-specific parts.
-
-### 1. UI Input Collection
-
-* **File**: `Wan2GP/wgp.py`
-* **Function**: `generate_video_tab` (around lines 4288-4842, VACE specific UI around 4458-4507)
- * This function defines the Gradio UI elements.
- * For VACE, it creates:
- * `video_prompt_type_video_guide`: A dropdown to select the type of video guidance (None, Pose, Depth, Color, VACE general, VACE with Mask). This determines letters like 'P', 'D', 'C', 'V', 'M' in the `video_prompt_type` string.
- * `video_prompt_type_image_refs`: A dropdown to enable/disable "Inject custom Faces / Objects", adding 'I' to `video_prompt_type`.
- * `video_guide`: A `gr.Video` component for the control video.
- * `keep_frames_video_guide`: A `gr.Text` input to specify frames to keep/mask from the `video_guide`.
- * `image_refs`: A `gr.Gallery` for reference images.
- * `remove_background_image_ref`: A `gr.Checkbox` for background removal on reference images.
- * `video_mask`: A `gr.Video` component for an explicit video mask (for inpainting/outpainting).
- * The choices made in these UI elements are collated into variables that are then passed to the backend processing.
-
-### 2. Input Validation and Task Creation
-
-* **File**: `Wan2GP/wgp.py`
-* **Function**: `process_prompt_and_add_tasks` (around lines 129-458, VACE specific logic around 301-368)
- * This function is triggered when a user adds a generation task.
- * If the selected model is a VACE model (checks if "Vace" is in `model_filename`):
- * It retrieves `video_prompt_type`, `image_refs`, `video_guide`, `video_mask`, etc., from the current UI state.
- * It validates that necessary inputs are provided based on the selected `video_prompt_type` (e.g., if 'I' is selected, `image_refs` must exist).
- * Reference images from the gallery are converted using `convert_image`.
- * These validated and potentially pre-processed inputs are then used to populate the arguments for `add_video_task`.
-
-### 3. Video Generation Orchestration
-
-* **File**: `Wan2GP/wgp.py`
-* **Function**: `generate_video` (lines 2648-3276)
- * This is the main worker function that performs video generation. It receives all parameters, including those for VACE, which were set up by the UI and `process_prompt_and_add_tasks`.
- * It sets a boolean `vace = "Vace" in model_filename` (line 2805).
- * **Background Removal for Image References (lines 2818-2824)**: If `image_refs` are present and VACE is active, it calls `wan.utils.utils.resize_and_remove_background` if the corresponding UI checkbox is ticked.
- * **VACE Input Preparation (lines 2986-3028, within a loop for sliding window processing)**:
- * If `vace` is true:
- * It makes copies of `image_refs`, `video_guide`, `video_mask` (as the `prepare_source` method might modify them).
- * **Control Video Preprocessing**: If `video_prompt_type` indicates 'P' (Pose), 'D' (Depth), or 'G' (Grayscale), it calls `preprocess_video` (defined in `wgp.py` around line 2540) to apply these effects to the `video_guide_copy`.
- * `parse_keep_frames_video_guide` processes the string specifying which frames to keep from the control video.
- * **`wan_model.prepare_source` Call**:
- ```python
- src_video, src_mask, src_ref_images = wan_model.prepare_source(
- [video_guide_copy],
- [video_mask_copy],
- [image_refs_copy],
- video_length, # Current window's length
- image_size=image_size,
- device="cpu",
- original_video="O" in video_prompt_type, # "O" for original video / alternate ending
- keep_frames=keep_frames_parsed,
- start_frame=guide_start_frame,
- pre_src_video=[pre_video_guide], # For sliding window continuity
- fit_into_canvas=(fit_canvas == 1)
- )
- ```
- - `wan_model` here is an instance of `WanT2V` (from `Wan2GP/wan/text2video.py`).
- - The `prepare_source` method (defined in `Wan2GP/wan/text2video.py`, lines 164-250) is crucial. It takes the raw video paths/data and converts them into the actual PyTorch tensors. It handles:
- - Loading video frames.
- - Resizing them to the target `image_size`.
- - Applying the `keep_frames` logic: frames not in `keep_frames_parsed` will have their corresponding `src_mask` set to 1 (indicating inpainting/regeneration), and `src_video` pixels might be zeroed out.
- - Padding videos/masks to the required `video_length`.
- - Preparing reference images.
- - The outputs `src_video`, `src_mask`, and `src_ref_images` are the tensors that will be fed into the VAE encoding part of the `WanT2V.generate` method.
- * **Calling `WanT2V.generate` (lines 3053-3091)**:
- * The `src_video`, `src_mask`, and `src_ref_images` tensors (prepared above) are passed to `wan_model.generate()` as `input_frames`, `input_masks`, and `input_ref_images` respectively.
- * This triggers the VACE encoding pipeline within `WanT2V` (using `vace_encode_frames`, `vace_encode_masks`, `vace_latent`) and the subsequent diffusion process detailed in the sections above.
-
-This flow shows how `wgp.py` translates user interactions and media uploads from the Gradio interface into the structured tensor inputs required by the `WanT2V` class for VACE-conditioned video generation.
-
-## Supporting Multiple Wan Encodings (Proposed Extension)
-
-The existing VACE system processes a single set of control inputs (frames, masks, reference images) to create one `vace_context`. To support multiple, separate Wan Encodings for a single generation, the following modifications could be considered:
-
-### 1. Input Handling and Encoding
-
-* **UI Enhancements (`wgp.py`)**:
- * The user interface would need to allow users to define multiple "control groups". Each group could consist of its own video guide, reference images, masks, and associated parameters (e.g., type of control like Pose, Depth, general VACE).
-* **Task Processing (`wgp.py`)**:
- * `process_prompt_and_add_tasks` would need to gather these multiple input groups, validating each one.
-* **Source Preparation (`WanT2V.prepare_source`)**:
- * This method would need to be called for each control group, or be adapted to process a list of input groups. It would output lists of tensors, e.g., `list_of_src_videos`, `list_of_src_masks`, `list_of_src_ref_images`.
-* **VACE Encoding (`WanT2V.generate`)**:
- * The method would receive these lists of source tensors.
- * It would iterate through each set of `src_video_i`, `src_mask_i`, `src_ref_images_i`.
- * For each set, it would call `self.vace_encode_frames(...)` and `self.vace_encode_masks(...)` to produce an individual `vace_context_i`.
- * The result would be a list of VACE contexts: `list_of_vace_contexts = [context1, context2, ..., contextN]`.
- * This list would be passed to the `WanModel`'s forward pass, for example, via `kwargs.update({'vace_contexts': list_of_vace_contexts, ...})`.
-
-### 2. `WanModel` Adaptations
-
-* **Initialization (`WanModel.__init__`)**:
- * The model would need to be aware of the number of VACE streams (e.g., via a `num_vace_streams` parameter).
- * `self.vace_patch_embedding`: This would likely become an `nn.ModuleList` of `nn.Conv3d` layers. Each convolutional layer in this list would correspond to one VACE stream, processing its respective context. These layers would be designed to output features that can be combined (e.g., summed).
-* **Forward Pass (`WanModel.forward`)**:
- * The method would accept `vace_contexts` (a list of tensors) instead of a single `vace_context`.
- * It would iterate through the `vace_contexts` list and the `self.vace_patch_embeddings` ModuleList, applying the corresponding patch embedding to each context: `processed_c_i = self.vace_patch_embeddings[i](vace_context_i)`.
- * This results in a list of processed context tensors: `processed_contexts = [pc1, pc2, ..., pcN]`.
- * The `hints_list` provided to the model's blocks would be structured so that each VACE-enabled block receives this full list `[pc1, pc2, ..., pcN]`. For example, `hints_list = [processed_contexts for _ in range(len(x_list))]`.
-
-### 3. `WanAttentionBlock` and `VaceWanAttentionBlock` Adaptations
-
-* **`WanAttentionBlock.forward`**:
- * When `self.block_id is not None` (it's a VACE layer), the `hints` argument it receives will be the list `[pc1, pc2, ..., pcN]`.
- * It passes this entire list to its `self.vace` module (the `VaceWanAttentionBlock` instance): `hint_output = self.vace(hints, x, **kwargs)`.
- * `hint_output` is expected to be a single tensor representing the combined influence from all VACE streams for that block.
- * This combined `hint_output` is added to the main feature map `x`: `x.add_(hint_output, alpha=context_scale)`. The `context_scale` would apply to the combined hint.
-* **`VaceWanAttentionBlock.forward`**: This block is responsible for combining the multiple VACE streams.
- * It receives the `hints` list (`[pc1, pc2, ..., pcN]`) and the current feature map `x` from the main U-Net path.
- * **Context Combination**: The list of processed context tensors `pc_i` are combined into a single tensor. Summation is a straightforward approach: `combined_pc = torch.sum(torch.stack(hints), dim=0)`. This assumes the patch embeddings have produced features in a compatible space.
- * **Feature Integration**:
- * Let `c_attention_input = combined_pc`.
- * If `self.block_id == 0` (typically the first VACE-enabled block in the model hierarchy for a given resolution): The combined context is first projected by `self.before_proj` and then added to the main path's features `x`. `c_attention_input = self.before_proj(combined_pc) + x`.
- * Else (for subsequent VACE blocks): `c_attention_input` can be just `combined_pc`, which is then passed to `super().forward()`. (Alternatively, `self.before_proj(combined_pc)` could be used consistently if `before_proj` is seen as a general adapter for the combined context before attention, regardless of `block_id`.)
- * **Core Processing**: The `c_attention_input` (which is the combined VACE information, potentially mixed with `x`) is processed through the standard attention and FFN layers of the parent `WanAttentionBlock` by calling `c_processed = super().forward(c_attention_input, **kwargs)`.
- * **Final Projection**: The output `c_processed` is passed through `self.after_proj` to get the final skip connection value for this block: `c_skip = self.after_proj(c_processed)`.
- * **Return Value**: The `c_skip` tensor is returned. This tensor represents the aggregated influence of all active Wan Encodings for the current block.
- * The original in-place update `hints[0] = c` (where `c` was `c_processed`) would likely be removed or re-evaluated, as its role in a multi-context scenario with independent streams is less clear and the documentation suggests hints are passed fresh from `WanModel.forward` in each step.
-
-By implementing these changes, the Wan2GP framework could leverage multiple, diverse control signals simultaneously, offering more nuanced and complex control over the video generation process.
\ No newline at end of file
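
A minimal PyTorch sketch of the combination step described above (class and layer names are illustrative stand-ins, not the actual Wan2GP classes; in the real design the parent block's attention/FFN would run via `super().forward()` between the two projections):

```python
import torch
import torch.nn as nn

class MultiVaceHintCombiner(nn.Module):
    """Toy sketch: fold N processed VACE contexts into one skip tensor."""

    def __init__(self, dim: int, block_id: int):
        super().__init__()
        self.block_id = block_id
        # Stand-ins for the block's real before_proj / after_proj layers.
        self.before_proj = nn.Linear(dim, dim)
        self.after_proj = nn.Linear(dim, dim)

    def forward(self, hints: list[torch.Tensor], x: torch.Tensor) -> torch.Tensor:
        # Sum the N processed context tensors into a single combined hint.
        combined = torch.sum(torch.stack(hints), dim=0)
        if self.block_id == 0:
            # First VACE-enabled block: project the combined context and mix in x.
            c = self.before_proj(combined) + x
        else:
            c = combined
        # ...the parent block's attention/FFN would process `c` here...
        return self.after_proj(c)  # c_skip, later added to x scaled by context_scale
```
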
diff --git a/generate_test_tasks.py b/generate_test_tasks.py
new file mode 100644
index 000000000..5e98f0a66
--- /dev/null
+++ b/generate_test_tasks.py
@@ -0,0 +1,614 @@
+#!/usr/bin/env python3
+"""
+Generate a set of end-to-end headless-server tasks that exercise the
+most important travel-segment scenarios:
+
+1. single image, no continue-video
+2. 3 images, no continue-video
+3. continue-video + single image
+4. continue-video + 2 images
+
+For each test a JSON payload (compatible with add_task.py) is written
+under tests/<test_name>/<test_name>_task.json, and the referenced
+input files (images / video) are copied into that same directory with
+clean, descriptive names so the folder is self-contained and easy to
+inspect.
+
+After writing the JSON the script *optionally* enqueues the task by
+invoking:
+
+    python add_task.py --type travel_orchestrator --params @<path_to_task_json>
+
+Pass --enqueue to actually perform the enqueue.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import shutil
+import sqlite3
+import subprocess
+import time
+import sys
+from pathlib import Path
+from datetime import datetime
+from typing import Dict, List, Tuple
+
+# ---------------------------------------------------------------------
+# Configuration helpers
+# ---------------------------------------------------------------------
+PROJECT_ID = "test_suite"
+BASE_PROMPT = "Car driving through a city, sky morphing"
+NEG_PROMPT = "chaotic"
+MODEL_NAME = "vace_14B" # Keeping VACE 14B model and will add proper image references
+DEFAULT_RESOLUTION = "500x500" # Fallback / default when not overridden
+FPS = 16
+SEGMENT_FRAMES_DEFAULT = 81 # will be quantised downstream (4n+1)
+FRAME_OVERLAP_DEFAULT = 12
+SEED_BASE = 11111
+
+# Determine output directory based on database type
+def get_output_dir_default():
+ """Get the appropriate output directory based on DB configuration."""
+ try:
+ # Import DB config to check type
+ from source import db_operations as db_ops
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH:
+ # For SQLite, use public/files to match the system convention
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ return str(sqlite_db_parent / "public" / "files")
+ else:
+ # For other DB types or when SQLite isn't configured, use outputs
+ return "./outputs"
+ except Exception:
+ # Fallback if DB config isn't available
+ return "./outputs"
+
+OUTPUT_DIR_DEFAULT = get_output_dir_default()
+
+SAMPLES_DIR = Path("samples")
+ASSET_IMAGES = [
+ SAMPLES_DIR / "1.png",
+ SAMPLES_DIR / "2.png",
+ SAMPLES_DIR / "3.png",
+]
+ASSET_VIDEO = SAMPLES_DIR / "test.mp4"
+
+TESTS_ROOT = Path("tests")
+
+# ---------------------------------------------------------------------
+
+def make_orchestrator_payload(*, run_id: str,
+ images: list[Path],
+ continue_video: Path | None,
+ num_segments: int,
+ resolution: str = DEFAULT_RESOLUTION,
+ additional_loras: dict | None = None) -> dict:
+ """Create the orchestrator_details dict used by headless server."""
+
+ # No VACE image references by default - images are used for guide video creation only
+ vace_image_refs = None
+
+ payload: dict = {
+ "run_id": run_id,
+ "input_image_paths_resolved": [str(p) for p in images],
+ "parsed_resolution_wh": resolution,
+ "model_name": MODEL_NAME,
+ "use_causvid_lora": True,
+ "num_new_segments_to_generate": num_segments,
+ "base_prompts_expanded": [BASE_PROMPT] * num_segments,
+ "negative_prompts_expanded": [NEG_PROMPT] * num_segments,
+ "segment_frames_expanded": [SEGMENT_FRAMES_DEFAULT] * num_segments,
+ "frame_overlap_expanded": [FRAME_OVERLAP_DEFAULT] * num_segments,
+ "fps_helpers": FPS,
+ "vace_image_refs_to_prepare_by_headless": vace_image_refs, # None instead of empty list
+ "fade_in_params_json_str": json.dumps({
+ "low_point": 0.0, "high_point": 1.0,
+ "curve_type": "ease_in_out", "duration_factor": 0.0
+ }),
+ "fade_out_params_json_str": json.dumps({
+ "low_point": 0.0, "high_point": 1.0,
+ "curve_type": "ease_in_out", "duration_factor": 0.0
+ }),
+ "seed_base": SEED_BASE,
+ "main_output_dir_for_run": OUTPUT_DIR_DEFAULT,
+ "debug_mode_enabled": True,
+ "skip_cleanup_enabled": True,
+ }
+ if continue_video is not None:
+ payload["continue_from_video_resolved_path"] = str(continue_video)
+
+ # Attach additional LoRAs if provided
+ if additional_loras:
+ payload["additional_loras"] = additional_loras
+ return payload
+
+
+def write_travel_test_case(name: str,
+ images: list[Path],
+ continue_video: Path | None,
+ num_segments: int,
+ resolution: str,
+ enqueue: bool,
+ additional_loras: dict | None = None) -> None:
+ """Materialise a test folder with JSON and asset copies."""
+ test_dir = TESTS_ROOT / name
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ # Copy / link assets with descriptive names and record new paths
+ copied_images: list[Path] = []
+ for idx, img in enumerate(images):
+ dst = test_dir / f"{name}_image{idx+1}{img.suffix}"
+ if dst.exists():
+ dst.unlink()
+ shutil.copy2(img, dst)
+ copied_images.append(dst)
+ if continue_video is not None:
+ dst_vid = test_dir / f"{name}_continue_video{continue_video.suffix}"
+ if dst_vid.exists():
+ dst_vid.unlink()
+ shutil.copy2(continue_video, dst_vid)
+ continue_video_path = dst_vid
+ else:
+ continue_video_path = None
+
+ # Build JSON payload (wrapper around orchestrator_details)
+ run_id = f"{name}_{datetime.utcnow().strftime('%Y%m%dT%H%M%S')}"
+ orch_payload = make_orchestrator_payload(
+ run_id=run_id,
+ images=copied_images,
+ continue_video=continue_video_path,
+ num_segments=num_segments,
+ resolution=resolution,
+ additional_loras=additional_loras,
+ )
+ task_json: dict = {
+ "project_id": PROJECT_ID,
+ "orchestrator_details": orch_payload,
+ }
+
+ json_path = test_dir / f"{name}_task.json"
+ with open(json_path, "w", encoding="utf-8") as fp:
+ json.dump(task_json, fp, indent=2)
+
+ try:
+ print(f"[WRITE] {json_path.resolve().relative_to(Path.cwd().resolve())}")
+ except ValueError:
+ # Fallback if paths are on different drives or unrelated
+ print(f"[WRITE] {json_path.resolve()}")
+
+ if enqueue:
+ cmd = [sys.executable, "add_task.py", "--type", "travel_orchestrator",
+ "--params", f"@{json_path}"]
+ print("[ENQUEUE]", " ".join(cmd))
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"[ERROR] add_task failed: {e}")
+
+
+def copy_results_for_comparison():
+ """Copy all test inputs and outputs to a single comparison directory."""
+ comparison_dir = Path("test_results_comparison")
+ comparison_dir.mkdir(exist_ok=True)
+
+ print(f"[COMPARISON] Creating comparison directory: {comparison_dir}")
+
+ # Get task outputs from database
+ task_outputs = get_task_outputs_from_db()
+
+ # Copy inputs and outputs for each test
+ for test_name in [
+ "travel_3_images_512",
+ "continue_video_1_image_512",
+ "different_perspective_pose_700x400",
+ "different_perspective_depth_640x480",
+ "single_image_1",
+ "single_image_2",
+ "single_image_3",
+ "single_image_4",
+ "single_image_5",
+ ]:
+ test_dir = Path("tests") / test_name
+ if not test_dir.exists():
+ print(f"[WARNING] Test directory not found: {test_dir}")
+ continue
+
+ # Create test-specific comparison subdirectory
+ comp_test_dir = comparison_dir / test_name
+ comp_test_dir.mkdir(exist_ok=True)
+
+ # Copy input files
+ inputs_dir = comp_test_dir / "inputs"
+ inputs_dir.mkdir(exist_ok=True)
+
+ for input_file in test_dir.glob("*"):
+ if input_file.is_file() and not input_file.name.endswith("_task.json"):
+ dst = inputs_dir / input_file.name
+ shutil.copy2(input_file, dst)
+ print(f"[COPY INPUT] {input_file} -> {dst}")
+
+ # Copy task JSON for reference
+ task_json = test_dir / f"{test_name}_task.json"
+ if task_json.exists():
+ shutil.copy2(task_json, comp_test_dir / "task_config.json")
+ print(f"[COPY CONFIG] {task_json} -> {comp_test_dir / 'task_config.json'}")
+
+ # Copy output files from database paths
+ outputs_dir = comp_test_dir / "outputs"
+ outputs_dir.mkdir(exist_ok=True)
+
+ # Copy outputs for this test from database
+ test_pattern = test_name.replace("test_", "").replace("_", "")
+ copied_count = 0
+ for task_id, output_path, status in task_outputs:
+ if test_pattern in task_id and output_path and status == "Complete":
+ # Convert database path (files/xyz.mp4) to full path (public/files/xyz.mp4)
+ if output_path.startswith("files/"):
+ full_output_path = Path("public") / output_path
+ else:
+ full_output_path = Path(output_path)
+
+ if full_output_path.exists():
+ dst = outputs_dir / f"{test_name}_{full_output_path.name}"
+ shutil.copy2(full_output_path, dst)
+ print(f"[COPY OUTPUT] {full_output_path} -> {dst}")
+ copied_count += 1
+ else:
+ print(f"[WARNING] Output file not found: {full_output_path}")
+
+ if copied_count == 0:
+ print(f"[INFO] No completed outputs found for {test_name}")
+
+ print(f"[COMPARISON] Results comparison ready in: {comparison_dir}")
+ return comparison_dir
+
+
+def get_task_outputs_from_db() -> List[Tuple[str, str, str]]:
+ """Query tasks.db for task outputs. Returns list of (task_id, output_location, status)."""
+ try:
+ conn = sqlite3.connect("tasks.db")
+ cursor = conn.cursor()
+
+ # Get all tasks with their output locations and status
+ cursor.execute("""
+ SELECT id, output_location, status
+ FROM tasks
+ WHERE project_id = 'test_suite'
+ ORDER BY created_at DESC
+ """)
+
+ results = cursor.fetchall()
+ conn.close()
+
+ print(f"[DB] Found {len(results)} test suite tasks in database")
+ return results
+
+ except Exception as e:
+ print(f"[ERROR] Failed to query database: {e}")
+ return []
+
+
+def wait_for_task_completion(max_wait_minutes: int = 30) -> None:
+ """Wait for all test_suite tasks to complete."""
+ print(f"[WAIT] Monitoring task completion (max {max_wait_minutes} minutes)...")
+
+ start_time = time.time()
+ max_wait_seconds = max_wait_minutes * 60
+
+ while True:
+ try:
+ conn = sqlite3.connect("tasks.db")
+ cursor = conn.cursor()
+
+ # Check for queued or in-progress test suite tasks
+ cursor.execute("""
+ SELECT COUNT(*)
+ FROM tasks
+ WHERE project_id = 'test_suite'
+ AND status IN ('Queued', 'In Progress')
+ """)
+
+ pending_count = cursor.fetchone()[0]
+
+ # Get completed/failed counts
+ cursor.execute("""
+ SELECT status, COUNT(*)
+ FROM tasks
+ WHERE project_id = 'test_suite'
+ GROUP BY status
+ """)
+
+ status_counts = dict(cursor.fetchall())
+ conn.close()
+
+ if pending_count == 0:
+ print(f"[WAIT] All tasks completed! Status summary: {status_counts}")
+ break
+
+ elapsed = time.time() - start_time
+ if elapsed > max_wait_seconds:
+ print(f"[WAIT] Timeout reached. Still pending: {pending_count}, Status: {status_counts}")
+ break
+
+ print(f"[WAIT] {pending_count} tasks still pending... (elapsed: {elapsed/60:.1f}min)")
+ time.sleep(10) # Check every 10 seconds
+
+ except Exception as e:
+ print(f"[ERROR] Failed to check task status: {e}")
+ break
+
+
+# ---------------------------------------------------------------------
+# New helper: Different-pose test case
+# ---------------------------------------------------------------------
+
+
+def write_different_perspective_test_case(name: str,
+ input_image: Path,
+ prompt: str,
+ resolution: str,
+ perspective_type: str = "pose",
+ enqueue: bool = False,
+ additional_loras: dict | None = None) -> None:
+ """Generate a different_perspective orchestrator task (single-image)."""
+ test_dir = TESTS_ROOT / name
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ # Copy input image into test dir for self-contained folder
+ dst_img = test_dir / f"{name}{input_image.suffix}"
+ if dst_img.exists():
+ dst_img.unlink()
+ shutil.copy2(input_image, dst_img)
+
+ run_id = f"{name}_{datetime.utcnow().strftime('%Y%m%dT%H%M%S')}"
+
+ task_json: dict = {
+ "project_id": PROJECT_ID,
+ "run_id": run_id,
+ "input_image_path": str(dst_img),
+ "prompt": prompt,
+ "model_name": MODEL_NAME,
+ "resolution": resolution,
+ "fps_helpers": FPS,
+ "output_video_frames": 30,
+ "seed": SEED_BASE,
+ "use_causvid_lora": True,
+ "debug_mode": True,
+ "skip_cleanup": True,
+ "perspective_type": perspective_type,
+ }
+
+ if additional_loras:
+ task_json["additional_loras"] = additional_loras
+
+ json_path = test_dir / f"{name}_task.json"
+ with open(json_path, "w", encoding="utf-8") as fp:
+ json.dump(task_json, fp, indent=2)
+
+ print(f"[WRITE] {json_path}")
+
+ if enqueue:
+ cmd = [sys.executable, "add_task.py", "--type", "different_perspective_orchestrator",
+ "--params", f"@{json_path}"]
+ print("[ENQUEUE]", " ".join(cmd))
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"[ERROR] add_task failed: {e}")
+
+
+# ---------------------------------------------------------------------
+# New helper: Single-image test case
+# ---------------------------------------------------------------------
+
+
+def write_single_image_test_case(name: str,
+ prompt: str,
+ resolution: str,
+ enqueue: bool,
+ additional_loras: dict | None = None) -> None:
+ """Create a single-image generation task (wgp single frame)."""
+ test_dir = TESTS_ROOT / name
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ run_id = f"{name}_{datetime.utcnow().strftime('%Y%m%dT%H%M%S')}"
+
+ task_json: dict = {
+ "project_id": PROJECT_ID,
+ "run_id": run_id,
+ "prompt": prompt,
+ "model": MODEL_NAME, # 'single_image' handler expects key 'model'
+ "resolution": resolution,
+ "seed": SEED_BASE,
+ "negative_prompt": NEG_PROMPT,
+ "use_causvid_lora": True,
+ }
+
+ if additional_loras:
+ task_json["additional_loras"] = additional_loras
+
+ json_path = test_dir / f"{name}_task.json"
+ with open(json_path, "w", encoding="utf-8") as fp:
+ json.dump(task_json, fp, indent=2)
+
+ print(f"[WRITE] {json_path}")
+
+ if enqueue:
+ cmd = [sys.executable, "add_task.py", "--type", "single_image",
+ "--params", f"@{json_path}"]
+ print("[ENQUEUE]", " ".join(cmd))
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"[ERROR] add_task failed: {e}")
+
+
+# ---------------------------------------------------------------------
+# Main CLI
+# ---------------------------------------------------------------------
+
+def main(args) -> None:
+ TESTS_ROOT.mkdir(exist_ok=True)
+
+ # Parse --loras JSON string into dict if provided
+ additional_loras_global: dict | None = None
+ if hasattr(args, "loras") and args.loras:
+ try:
+ additional_loras_global = json.loads(args.loras)
+ except Exception as e_l_json:
+ print(f"[ERROR] Failed to parse --loras JSON: {e_l_json}")
+ sys.exit(1)
+
+ # Generate tasks according to requested --task-type
+
+ # ------------------------------------------------------------
+ # travel_between_images
+ # ------------------------------------------------------------
+ if args.task_type == "travel_between_images":
+
+ # 1) Continue-video + 1 image → 512×512
+
+ write_travel_test_case(
+ name="continue_video_1_image_512",
+ images=[ASSET_IMAGES[0]],
+ continue_video=ASSET_VIDEO,
+ num_segments=1,
+ resolution="512x512",
+ enqueue=args.enqueue,
+ additional_loras=additional_loras_global,
+ )
+
+ # 2) Travel orchestrator: 3 images → 512×512
+ write_travel_test_case(
+ name="travel_3_images_512",
+ images=ASSET_IMAGES,
+ continue_video=None,
+ num_segments=3,
+ resolution="512x512",
+ enqueue=args.enqueue,
+ additional_loras=additional_loras_global,
+ )
+
+
+
+ # ------------------------------------------------------------
+ # different_perspective
+ # ------------------------------------------------------------
+ elif args.task_type == "different_perspective":
+ # Test pose-based perspective change
+ write_different_perspective_test_case(
+ name="different_perspective_pose_700x400",
+ input_image=SAMPLES_DIR / "pose.png",
+ prompt="Person standing in a desert sunset, cinematic lighting",
+ resolution="700x400",
+ perspective_type="pose",
+ enqueue=args.enqueue,
+ additional_loras=additional_loras_global,
+ )
+
+ # Test depth-based perspective change
+ write_different_perspective_test_case(
+ name="different_perspective_depth_640x480",
+ input_image=SAMPLES_DIR / "1.png",
+ prompt="Cinematic view from a different angle, dramatic lighting",
+ resolution="640x480",
+ perspective_type="depth",
+ enqueue=args.enqueue,
+ additional_loras=additional_loras_global,
+ )
+
+ # ------------------------------------------------------------
+ # single_image
+ # ------------------------------------------------------------
+ elif args.task_type == "single_image":
+ single_image_specs = [
+ ("single_image_1", "studio ghibli style: A serene mountain landscape at sunset", "640x360"),
+ ("single_image_2", "studio ghibli style: A futuristic city skyline at night", "512x768"),
+ ("single_image_3", "studio ghibli style: A cute puppy in a field of flowers", "768x512"),
+ ("single_image_4", "studio ghibli style: A spaceship traveling through hyperspace", "720x1280"),
+ ("single_image_5", "studio ghibli style: A vibrant abstract painting of shapes and colours", "1024x576"),
+ ]
+
+ ghibli_lora_url = "https://huggingface.co/peteromallet/ad_motion_loras/resolve/main/studio_ghibli_wan14b_t2v_v01.safetensors"
+
+ for name, prompt, res in single_image_specs:
+ # Ensure every single-image task applies the Studio-Ghibli LoRA (strength 1.0)
+ loras_to_use = {ghibli_lora_url: 1.0}
+ # Merge any global --loras values (caller-supplied) so they stack as well
+ if additional_loras_global:
+ loras_to_use.update(additional_loras_global)
+
+ write_single_image_test_case(
+ name=name,
+ prompt=prompt,
+ resolution=res,
+ enqueue=args.enqueue,
+ additional_loras=loras_to_use,
+ )
+
+ print("\nAll test cases generated under", TESTS_ROOT.resolve())
+
+ if args.compare:
+ if not args.no_wait:
+ print("[INFO] Waiting for task completion before comparison (use --no-wait to skip)")
+ wait_for_task_completion(args.wait_minutes)
+ copy_results_for_comparison()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--enqueue", action="store_true",
+ help="Only generate tests and enqueue them (no comparison)")
+ parser.add_argument("--compare", action="store_true", help="Only create comparison directory (no new enqueue)")
+ parser.add_argument("--no-wait", action="store_true", help="Skip waiting for task completion (compare immediately)")
+ parser.add_argument("--wait-minutes", type=int, default=30, help="Max minutes to wait for completion when waiting is enabled")
+ parser.add_argument(
+ "--task-type",
+ choices=["different_perspective", "travel_between_images", "single_image"],
+ default="different_perspective",
+ help=(
+ "Select which kind of test task(s) to generate. "
+ "'different_perspective' (default) creates both pose and depth perspective-variation tasks; "
+ "'travel_between_images' creates the orchestrator travel tasks; "
+ "'single_image' creates the five single-image tasks."
+ ),
+ )
+ parser.add_argument(
+ "--loras",
+ type=str,
+ help="JSON dictionary mapping LoRA URLs/paths to strength values to attach to generated tasks (e.g. '{\"path_or_url\": 0.8}')",
+ )
+ args = parser.parse_args()
+
+ # -------------------------------------------------
+ # 1. Explicit --enqueue only
+ # -------------------------------------------------
+ if args.enqueue and not args.compare:
+ main(args) # generate & enqueue only
+ print("[INFO] Enqueue-only mode complete. No comparison requested.")
+ sys.exit(0)
+
+ # -------------------------------------------------
+ # 2. Explicit --compare only
+ # -------------------------------------------------
+ if args.compare and not args.enqueue:
+ if not args.no_wait:
+ print("[INFO] Waiting for task completion before comparison (use --no-wait to skip)")
+ wait_for_task_completion(args.wait_minutes)
+ copy_results_for_comparison()
+ sys.exit(0)
+
+ # -------------------------------------------------
+ # 3. DEFAULT / combined path: enqueue + wait + compare
+ # (Triggered when no flags or both flags given.)
+ # -------------------------------------------------
+ print("[INFO] Running full cycle: generate/enqueue tests, wait for completion, then compare results.")
+ main(args) # generate and enqueue tests
+
+ if not args.no_wait:
+ wait_for_task_completion(args.wait_minutes)
+ else:
+ print("[INFO] --no-wait specified: skipping wait phase.")
+
+ copy_results_for_comparison()
+ print("[INFO] Full test cycle complete.")
\ No newline at end of file
diff --git a/headless.py b/headless.py
index d7e32759f..88aefb802 100644
--- a/headless.py
+++ b/headless.py
@@ -18,98 +18,48 @@
import argparse
import sys
import os
-import types
-from pathlib import Path
-from PIL import Image
import json
import time
-# import shutil # No longer moving files, tasks are removed from tasks.json
+import datetime
import traceback
-import requests # For downloading the LoRA
-import inspect # Added import
-import sqlite3 # Added for SQLite database
-import urllib.parse # Added for URL encoding
-import threading
-import uuid # Added import for UUID
-
-# --- Add imports for OpenPose generation ---
-import numpy as np
-try:
- # Import from Wan2GP submodule
- import sys
- from pathlib import Path
- wan2gp_path = Path(__file__).parent / "Wan2GP"
- if str(wan2gp_path) not in sys.path:
- sys.path.insert(0, str(wan2gp_path))
- from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
-except ImportError:
- PoseBodyFaceVideoAnnotator = None # Allow script to load if module not found, error out at runtime
-# --- End OpenPose imports ---
-
-from dotenv import load_dotenv # For .env file
-# import psycopg2 # For PostgreSQL - REMOVED
-# import psycopg2.extras # For dictionary cursor with PostgreSQL - REMOVED
-from supabase import create_client, Client as SupabaseClient # For Supabase
-import tempfile # For temporary directories
-import shutil # For file operations
-import cv2 # Added for RIFE interpolation
+import urllib.parse
+import tempfile
+import shutil
+
+from pathlib import Path
+from PIL import Image
+from dotenv import load_dotenv
+from supabase import create_client, Client as SupabaseClient
+
+# Add the current directory to Python path so Wan2GP can be imported as a module
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+# Add the Wan2GP subdirectory to the path for its internal imports
+wan2gp_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Wan2GP")
+if wan2gp_path not in sys.path:
+ sys.path.append(wan2gp_path)
# --- SM_RESTRUCTURE: Import moved/new utilities ---
-from Wan2GP.sm_functions.common_utils import (
- # dprint is already defined locally in headless.py
- generate_unique_task_id as sm_generate_unique_task_id, # Alias to avoid conflict if headless has its own
- add_task_to_db as sm_add_task_to_db,
- # poll_task_status, # headless doesn't poll for sub-tasks this way
- get_video_frame_count_and_fps as sm_get_video_frame_count_and_fps,
- _get_unique_target_path as sm_get_unique_target_path,
- image_to_frame as sm_image_to_frame,
- create_color_frame as sm_create_color_frame,
- _adjust_frame_brightness as sm_adjust_frame_brightness,
- _copy_to_folder_with_unique_name as sm_copy_to_folder_with_unique_name,
- _apply_strength_to_image as sm_apply_strength_to_image,
- parse_resolution as sm_parse_resolution # For parsing resolution string from orchestrator
+from source import db_operations as db_ops
+from source.specialized_handlers import (
+ handle_generate_openpose_task,
+ handle_rife_interpolate_task,
+ handle_extract_frame_task
)
-from Wan2GP.sm_functions.video_utils import (
- extract_frames_from_video as sm_extract_frames_from_video,
- create_video_from_frames_list as sm_create_video_from_frames_list,
- cross_fade_overlap_frames as sm_cross_fade_overlap_frames,
- _apply_saturation_to_video_ffmpeg as sm_apply_saturation_to_video_ffmpeg,
- # color_match_video_to_reference # If needed by stitch/segment tasks
-)
-from Wan2GP.sm_functions.travel_between_images import (
- get_easing_function as sm_get_easing_function # For guide video fades
+from source.common_utils import (
+ sm_get_unique_target_path,
+ download_image_if_url as sm_download_image_if_url,
+ download_file,
+ load_pil_images as sm_load_pil_images,
+ build_task_state,
+ prepare_output_path_with_upload,
+ upload_and_get_final_output_location
)
+from source.sm_functions import travel_between_images as tbi
+from source.sm_functions import different_perspective as dp
+from source.sm_functions import single_image as si
+from source.sm_functions import magic_edit as me
# --- End SM_RESTRUCTURE imports ---
-# -----------------------------------------------------------------------------
-# Global DB Configuration (will be set in main)
-# -----------------------------------------------------------------------------
-DB_TYPE = "sqlite" # Default to sqlite, will be changed to "supabase" if configured
-# PG_DSN = None # REMOVED - Supabase client handles connection
-PG_TABLE_NAME = "tasks" # Still needed for RPC calls (table name in Supabase/Postgres)
-SQLITE_DB_PATH = "tasks.db"
-SUPABASE_URL = None
-SUPABASE_SERVICE_KEY = None
-SUPABASE_VIDEO_BUCKET = "videos" # Default bucket name, can be overridden by .env
-SUPABASE_CLIENT: SupabaseClient | None = None # Global Supabase client instance
-
-# Add SQLite connection lock for thread safety
-sqlite_lock = threading.Lock()
-
-# SQLite retry configuration
-SQLITE_MAX_RETRIES = 5
-SQLITE_RETRY_DELAY = 0.5 # seconds
-
-# Global variable for ComfyUI Output Path
-COMFYUI_OUTPUT_PATH_CONFIG: Path | None = None
-
-# -----------------------------------------------------------------------------
-# Status Constants
-# -----------------------------------------------------------------------------
-STATUS_QUEUED = "Queued"
-STATUS_IN_PROGRESS = "In Progress"
-STATUS_COMPLETE = "Complete"
-STATUS_FAILED = "Failed"
# -----------------------------------------------------------------------------
# Debug / Verbose Logging Helpers
@@ -121,7 +71,7 @@ def dprint(msg: str):
"""Print a debug message if --debug flag is enabled."""
if debug_mode:
# Prefix with timestamp for easier tracing
- print(f"[DEBUG {time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")
+ print(f"[DEBUG {datetime.datetime.now().isoformat()}] {msg}")
# -----------------------------------------------------------------------------
# 1. Parse arguments for the server
@@ -139,6 +89,33 @@ def parse_args():
help="How often (in seconds) to check tasks.json for new tasks.")
pgroup_server.add_argument("--debug", action="store_true",
help="Enable verbose debug logging (prints additional diagnostics)")
+ pgroup_server.add_argument("--worker", type=str, default=None,
+ help="Worker name/ID - creates a log file named {worker}.log in the logs folder")
+ pgroup_server.add_argument("--save-logging", type=str, nargs='?', const='logs/headless.log', default=None,
+ help="Save all logging output to a file (in addition to console output). Optionally specify path, defaults to 'logs/headless.log'")
+ pgroup_server.add_argument("--delete-db", action="store_true",
+ help="Delete existing database files before starting (fresh start)")
+ pgroup_server.add_argument("--migrate-only", action="store_true",
+ help="Run database migrations and then exit.")
+ pgroup_server.add_argument("--apply-reward-lora", action="store_true",
+ help="Apply the reward LoRA with a fixed strength of 0.5.")
+ pgroup_server.add_argument("--colour-match-videos", action="store_true",
+ help="Apply colour matching to travel videos.")
+ # --- New flag: automatically generate and pass a video mask marking active/inactive frames ---
+ pgroup_server.add_argument("--mask-active-frames", dest="mask_active_frames", action="store_true", default=True,
+ help="Generate and pass a mask video where frames that are re-used remain unmasked while new frames are masked (enabled by default).")
+ pgroup_server.add_argument("--no-mask-active-frames", dest="mask_active_frames", action="store_false",
+ help="Disable automatic mask video generation.")
+
+ # --- New Supabase-related arguments ---
+ pgroup_server.add_argument("--db-type", type=str, choices=["sqlite", "supabase"], default="sqlite",
+ help="Database type to use (default: sqlite)")
+ pgroup_server.add_argument("--supabase-url", type=str, default=None,
+ help="Supabase project URL (required if db_type = supabase)")
+ pgroup_server.add_argument("--supabase-access-token", type=str, default=None,
+ help="Supabase access token (JWT) for authentication (required if db_type = supabase)")
+ pgroup_server.add_argument("--supabase-anon-key", type=str, default=None,
+ help="Supabase anon (public) API key used to create the client when authenticating with a user JWT. If omitted, falls back to SUPABASE_ANON_KEY env var or service key.")
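+    # Example (illustrative): run the worker against Supabase instead of local SQLite:
+    #   python headless.py --db-type supabase --supabase-url https://<project>.supabase.co \
+    #       --supabase-access-token <jwt> --worker gpu-worker-01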
# Advanced wgp.py Global Config Overrides (Optional) - Applied once at server start
pgroup_wgp_globals = parser.add_argument_group("WGP Global Config Overrides (Applied at Server Start)")
@@ -185,2025 +162,576 @@ def decorator(func):
gr.SelectData = type('SelectData', (), {'index': None, '_data': None})
gr.EventData = type('EventData', (), {'target':None, '_data':None})
-# -----------------------------------------------------------------------------
-# Database Helper Functions with Improved SQLite Handling
-# -----------------------------------------------------------------------------
-
-def execute_sqlite_with_retry(db_path_str: str, operation_func, *args, **kwargs):
- """Execute SQLite operations with retry logic for handling locks and I/O errors"""
- for attempt in range(SQLITE_MAX_RETRIES):
- try:
- with sqlite_lock: # Ensure only one thread accesses SQLite at a time
- conn = sqlite3.connect(db_path_str, timeout=30.0) # 30 second timeout
- conn.execute("PRAGMA journal_mode=WAL") # Enable WAL mode for better concurrency
- conn.execute("PRAGMA synchronous=NORMAL") # Balance between safety and performance
- conn.execute("PRAGMA busy_timeout=30000") # 30 second busy timeout
- try:
- result = operation_func(conn, *args, **kwargs)
- conn.commit()
- return result
- finally:
- conn.close()
- except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
- error_msg = str(e).lower()
- if "database is locked" in error_msg or "disk i/o error" in error_msg or "database disk image is malformed" in error_msg:
- if attempt < SQLITE_MAX_RETRIES - 1:
- wait_time = SQLITE_RETRY_DELAY * (2 ** attempt) # Exponential backoff
- print(f"SQLite error on attempt {attempt + 1}: {e}. Retrying in {wait_time:.1f}s...")
- time.sleep(wait_time)
- continue
- else:
- print(f"SQLite error after {SQLITE_MAX_RETRIES} attempts: {e}")
- raise
- else:
- # For other SQLite errors, don't retry
- raise
- except Exception as e:
- # For non-SQLite errors, don't retry
- raise
-
- raise sqlite3.OperationalError(f"Failed to execute SQLite operation after {SQLITE_MAX_RETRIES} attempts")
-
-def _migrate_sqlite_schema(db_path_str: str):
- """Applies necessary schema migrations to an existing SQLite database."""
- dprint(f"SQLite Migration: Checking schema for {db_path_str}...")
- try:
- def migration_operations(conn):
- cursor = conn.cursor()
-
- # Check if task_type column exists
- cursor.execute(f"PRAGMA table_info(tasks)")
- columns = [row[1] for row in cursor.fetchall()]
- task_type_column_exists = 'task_type' in columns
-
- if not task_type_column_exists:
- dprint("SQLite Migration: 'task_type' column not found. Adding it.")
- cursor.execute("ALTER TABLE tasks ADD COLUMN task_type TEXT") # Add as nullable first
- conn.commit() # Commit alter table before data migration
- dprint("SQLite Migration: 'task_type' column added.")
- else:
- dprint("SQLite Migration: 'task_type' column already exists.")
-
- # Populate task_type from params if it's NULL (for old rows or newly added column)
- dprint("SQLite Migration: Attempting to populate NULL 'task_type' from 'params' JSON...")
- cursor.execute("SELECT task_id, params FROM tasks WHERE task_type IS NULL")
- rows_to_migrate = cursor.fetchall()
-
- migrated_count = 0
- for task_id, params_json_str in rows_to_migrate:
- try:
- params_dict = json.loads(params_json_str)
- # Attempt to get task_type from common old locations within params
- # The user might need to adjust these keys if their old storage was different
- old_task_type = params_dict.get("task_type") # Most likely if it was in params
-
- if old_task_type:
- dprint(f"SQLite Migration: Found task_type '{old_task_type}' in params for task_id {task_id}. Updating row.")
- cursor.execute("UPDATE tasks SET task_type = ? WHERE task_id = ?", (old_task_type, task_id))
- migrated_count += 1
- else:
- # If task_type is not in params, it might be inferred from 'model' or other fields
- # For instance, if 'model' field implied the task type for older tasks.
- # This part is highly dependent on previous conventions.
- # As a simple default, if not found, it will remain NULL unless a default is set.
- # For 'travel_between_images' and 'different_pose', these are typically set by steerable_motion.py
- # and wouldn't exist as 'task_type' inside params for headless.py's default processing.
- # Headless tasks like 'generate_openpose' *did* use task_type in params.
- dprint(f"SQLite Migration: No 'task_type' key in params for task_id {task_id}. It will remain NULL or needs manual/specific migration logic if it was inferred differently.")
- except json.JSONDecodeError:
- dprint(f"SQLite Migration: Could not parse params JSON for task_id {task_id}. Skipping 'task_type' population for this row.")
- except Exception as e_row:
- dprint(f"SQLite Migration: Error processing row for task_id {task_id}: {e_row}")
-
- if migrated_count > 0:
- conn.commit()
- dprint(f"SQLite Migration: Populated 'task_type' for {migrated_count} rows from params.")
-
- # Default remaining NULL task_types for old standard tasks
- # This ensures rows that didn't have an explicit 'task_type' in their params (e.g. old default WGP tasks)
- # get a value, respecting the NOT NULL constraint if the table is new or fully validated.
- default_task_type_for_old_rows = "standard_wgp_task"
- cursor.execute(
- f"UPDATE tasks SET task_type = ? WHERE task_type IS NULL",
- (default_task_type_for_old_rows,)
- )
- updated_to_default_count = cursor.rowcount
- if updated_to_default_count > 0:
- conn.commit()
- dprint(f"SQLite Migration: Updated {updated_to_default_count} older rows with NULL task_type to default '{default_task_type_for_old_rows}'.")
-
- dprint("SQLite Migration: Schema check and population attempt complete.")
- return True
-
- execute_sqlite_with_retry(db_path_str, migration_operations)
-
- except Exception as e:
- print(f"[ERROR] SQLite Migration: Failed to migrate schema for {db_path_str}: {e}")
- traceback.print_exc()
- # Depending on severity, you might want to sys.exit(1)
-
-def init_db(db_path_str: str):
- """Initialize the SQLite database with proper error handling"""
- def _init_operation(conn):
- cursor = conn.cursor()
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS tasks (
- task_id TEXT PRIMARY KEY,
- status TEXT NOT NULL DEFAULT 'Queued',
- params TEXT NOT NULL,
- task_type TEXT NOT NULL,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- output_location TEXT
- )
- """)
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_status_created ON tasks(status, created_at)")
- return True
-
- try:
- execute_sqlite_with_retry(db_path_str, _init_operation)
- print(f"SQLite database initialized: {db_path_str}")
- except Exception as e:
- print(f"Failed to initialize SQLite database: {e}")
- sys.exit(1)
-
-def get_oldest_queued_task(db_path_str: str):
- """Get the oldest queued task with proper error handling"""
- def _get_operation(conn):
- cursor = conn.cursor()
- cursor.execute("SELECT task_id, params, task_type FROM tasks WHERE status = 'Queued' ORDER BY created_at ASC LIMIT 1")
- task_row = cursor.fetchone()
- if task_row:
- return {"task_id": task_row[0], "params": json.loads(task_row[1]), "task_type": task_row[2]}
- return None
-
- try:
- return execute_sqlite_with_retry(db_path_str, _get_operation)
- except Exception as e:
- print(f"Error getting oldest queued task: {e}")
- return None
-
-def update_task_status(db_path_str: str, task_id: str, status: str, output_location_val: str | None = None):
- """Updates a task's status and updated_at timestamp with proper error handling"""
- def _update_operation(conn, task_id, status, output_location_val):
- cursor = conn.cursor()
- if status == STATUS_COMPLETE and output_location_val is not None:
- cursor.execute("UPDATE tasks SET status = ?, updated_at = CURRENT_TIMESTAMP, output_location = ? WHERE task_id = ?",
- (status, output_location_val, task_id))
- else:
- cursor.execute("UPDATE tasks SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE task_id = ?", (status, task_id))
- return True
-
- try:
- execute_sqlite_with_retry(db_path_str, _update_operation, task_id, status, output_location_val)
- dprint(f"SQLite: Updated status of task {task_id} to {status}. Output: {output_location_val if output_location_val else 'N/A'}")
- except Exception as e:
- print(f"Error updating task status for {task_id}: {e}")
- # Don't raise here to avoid crashing the main loop
-# -----------------------------------------------------------------------------
-# 3. Minimal send_cmd implementation (task_id instead of task_name)
-# -----------------------------------------------------------------------------
def make_send_cmd(task_id):
def _send(cmd, data=None):
prefix = f"[Task ID: {task_id}]"
if cmd == "progress":
if isinstance(data, list) and len(data) >= 2:
prog, txt = data[0], data[1]
- if isinstance(prog, tuple) and len(prog) == 2: step, total = prog; print(f"{prefix}[Progress] {step}/{total} – {txt}")
- else: print(f"{prefix}[Progress] {txt}")
- elif cmd == "status": print(f"{prefix}[Status] {data}")
- elif cmd == "info": print(f"{prefix}[INFO] {data}")
- elif cmd == "error": print(f"{prefix}[ERROR] {data}"); raise RuntimeError(f"wgp.py error for {task_id}: {data}")
- elif cmd == "output": print(f"{prefix}[Output] video written.")
-
- print(f"DEBUG: Signature of _send for task {task_id}: {inspect.signature(_send)}") # Added diagnostic print
+ if isinstance(prog, tuple) and len(prog) == 2:
+ step, total = prog
+ print(f"{prefix}[Progress] {step}/{total} – {txt}")
+ else:
+ print(f"{prefix}[Progress] {txt}")
+ elif cmd == "status":
+ print(f"{prefix}[Status] {data}")
+ elif cmd == "info":
+ print(f"{prefix}[INFO] {data}")
+ elif cmd == "error":
+ print(f"{prefix}[ERROR] {data}")
+ raise RuntimeError(f"wgp.py error for {task_id}: {data}")
+ elif cmd == "output":
+ print(f"{prefix}[Output] video written.")
return _send
-# -----------------------------------------------------------------------------
-# 4. State builder for a single task (same as before)
-# -----------------------------------------------------------------------------
-def build_task_state(wgp_mod, model_filename, task_params_dict, all_loras_for_model):
- state = {
- "model_filename": model_filename,
- "validate_success": 1,
- "advanced": True,
- "gen": {"queue": [], "file_list": [], "file_settings_list": [], "prompt_no": 1, "prompts_max": 1},
- "loras": all_loras_for_model,
- }
- model_type_key = wgp_mod.get_model_type(model_filename)
- ui_defaults = wgp_mod.get_default_settings(model_filename).copy()
-
- # Override with task_params from JSON, but preserve some crucial ones if CausVid is used
- causvid_active = task_params_dict.get("use_causvid_lora", False)
-
- for key, value in task_params_dict.items():
- if key not in ["output_sub_dir", "model", "task_id", "use_causvid_lora"]:
- if causvid_active and key in ["steps", "guidance_scale", "flow_shift", "activated_loras", "loras_multipliers"]:
- continue # These will be set by causvid logic if flag is true
- ui_defaults[key] = value
+def _ensure_lora_downloaded(task_id: str, lora_name: str, lora_url: str, target_dir: Path,
+ task_params_dict: dict, flag_key: str, model_filename: str = None,
+ model_requirements: list = None) -> bool:
+ """
+ Shared helper to download a LoRA if it doesn't exist.
- ui_defaults["prompt"] = task_params_dict.get("prompt", "Default prompt")
- ui_defaults["resolution"] = task_params_dict.get("resolution", "832x480")
- # Allow task to specify frames/video_length, steps, guidance_scale, flow_shift unless overridden by CausVid
- if not causvid_active:
- ui_defaults["video_length"] = task_params_dict.get("frames", task_params_dict.get("video_length", 81))
- ui_defaults["num_inference_steps"] = task_params_dict.get("steps", task_params_dict.get("num_inference_steps", 30))
- ui_defaults["guidance_scale"] = task_params_dict.get("guidance_scale", ui_defaults.get("guidance_scale", 5.0))
- ui_defaults["flow_shift"] = task_params_dict.get("flow_shift", ui_defaults.get("flow_shift", 3.0))
- else: # CausVid specific defaults if not touched by its logic yet
- ui_defaults["video_length"] = task_params_dict.get("frames", task_params_dict.get("video_length", 81))
- # steps, guidance_scale, flow_shift will be set below by causvid logic
-
- ui_defaults["seed"] = task_params_dict.get("seed", -1)
- ui_defaults["lset_name"] = ""
-
- def load_pil_images(paths_list_or_str, wgp_convert_func):
- if paths_list_or_str is None: return None
- paths_list = paths_list_or_str if isinstance(paths_list_or_str, list) else [paths_list_or_str]
- images = []
- for p_str in paths_list:
- p = Path(p_str.strip())
- if not p.is_file(): print(f"[WARNING] Image file not found: {p}"); continue
- try:
- img = Image.open(p)
- images.append(wgp_convert_func(img))
- except Exception as e:
- print(f"[WARNING] Failed to load image {p}: {e}")
- return images if images else None
-
- if task_params_dict.get("image_start_paths"):
- loaded = load_pil_images(task_params_dict["image_start_paths"], wgp_mod.convert_image)
- if loaded: ui_defaults["image_start"] = loaded
- if task_params_dict.get("image_end_paths"):
- loaded = load_pil_images(task_params_dict["image_end_paths"], wgp_mod.convert_image)
- if loaded: ui_defaults["image_end"] = loaded
- if task_params_dict.get("image_refs_paths"):
- loaded = load_pil_images(task_params_dict["image_refs_paths"], wgp_mod.convert_image)
- if loaded: ui_defaults["image_refs"] = loaded
+ Args:
+ task_id: Task identifier for logging
+ lora_name: Filename of the LoRA to download
+        lora_url: URL to download the LoRA from
+ target_dir: Directory to save the LoRA
+ task_params_dict: Task parameters dictionary to modify if download fails
+ flag_key: Key in task_params_dict to set to False if download fails
+ model_filename: Optional model filename for requirement checking
+ model_requirements: Optional list of strings that must be in model_filename
- for key in ["video_source_path", "video_guide_path", "video_mask_path", "audio_guide_path"]:
- if task_params_dict.get(key):
- ui_defaults[key.replace("_path","")] = task_params_dict[key]
-
- if task_params_dict.get("prompt_enhancer_mode"):
- ui_defaults["prompt_enhancer"] = task_params_dict["prompt_enhancer_mode"]
- wgp_mod.server_config["enhancer_enabled"] = 1
- elif "prompt_enhancer" not in task_params_dict:
- ui_defaults["prompt_enhancer"] = ""
- wgp_mod.server_config["enhancer_enabled"] = 0
-
- # --- Custom LoRA Handling (e.g., from lora_name) ---
- custom_lora_name_stem = task_params_dict.get("lora_name")
- task_id_for_dprint = task_params_dict.get('task_id', 'N/A') # For logging
-
- if custom_lora_name_stem:
- custom_lora_filename = f"{custom_lora_name_stem}.safetensors"
- dprint(f"[Task ID: {task_id_for_dprint}] Custom LoRA specified via lora_name: {custom_lora_filename}")
-
- # Ensure activated_loras is a list
- activated_loras_val = ui_defaults.get("activated_loras", [])
- if isinstance(activated_loras_val, str):
- # Handles comma-separated string from task_params or previous logic
- current_activated_list = [str(item).strip() for item in activated_loras_val.split(',') if item.strip()]
- elif isinstance(activated_loras_val, list):
- current_activated_list = list(activated_loras_val) # Ensure it's a mutable copy
- else:
- dprint(f"[Task ID: {task_id_for_dprint}] Unexpected type for activated_loras: {type(activated_loras_val)}. Initializing as empty list.")
- current_activated_list = []
-
- if custom_lora_filename not in current_activated_list:
- current_activated_list.append(custom_lora_filename)
- dprint(f"[Task ID: {task_id_for_dprint}] Added '{custom_lora_filename}' to activated_loras list: {current_activated_list}")
-
- # Handle multipliers: Add a default "1.0" if a LoRA was added and multipliers are potentially mismatched
- loras_multipliers_str = ui_defaults.get("loras_multipliers", "")
- if isinstance(loras_multipliers_str, (list, tuple)):
- loras_multipliers_list = [str(m).strip() for m in loras_multipliers_str if str(m).strip()] # Convert all to string and clean
- elif isinstance(loras_multipliers_str, str):
- loras_multipliers_list = [m.strip() for m in loras_multipliers_str.split(" ") if m.strip()] # Space-separated string
- else:
- dprint(f"[Task ID: {task_id_for_dprint}] Unexpected type for loras_multipliers: {type(loras_multipliers_str)}. Initializing as empty list.")
- loras_multipliers_list = []
-
- # If number of multipliers is less than activated LoRAs, pad with "1.0"
- while len(loras_multipliers_list) < len(current_activated_list):
- loras_multipliers_list.append("1.0")
- dprint(f"[Task ID: {task_id_for_dprint}] Padded loras_multipliers with '1.0'. Now: {loras_multipliers_list}")
-
- ui_defaults["loras_multipliers"] = " ".join(loras_multipliers_list)
- else:
- dprint(f"[Task ID: {task_id_for_dprint}] Custom LoRA '{custom_lora_filename}' already in activated_loras list.")
-
- ui_defaults["activated_loras"] = current_activated_list # Update ui_defaults
- # --- End Custom LoRA Handling ---
-
- # --- Handle remove_background_image_ref legacy key ---
- if "remove_background_image_ref" in ui_defaults and "remove_background_images_ref" not in ui_defaults:
- ui_defaults["remove_background_images_ref"] = ui_defaults["remove_background_image_ref"]
- # --- End Handle remove_background_image_ref ---
-
- # Apply CausVid LoRA specific settings if the flag is true
- if causvid_active:
- print(f"[Task ID: {task_params_dict.get('task_id')}] Applying CausVid LoRA settings.")
-
- # If steps are specified in the task JSON for a CausVid task, use them; otherwise, default to 9.
- if "steps" in task_params_dict:
- ui_defaults["num_inference_steps"] = task_params_dict["steps"]
- print(f"[Task ID: {task_params_dict.get('task_id')}] CausVid task using specified steps: {ui_defaults['num_inference_steps']}")
- elif "num_inference_steps" in task_params_dict:
- ui_defaults["num_inference_steps"] = task_params_dict["num_inference_steps"]
- print(f"[Task ID: {task_params_dict.get('task_id')}] CausVid task using specified num_inference_steps: {ui_defaults['num_inference_steps']}")
- else:
- ui_defaults["num_inference_steps"] = 9 # Default for CausVid if not specified in task
- print(f"[Task ID: {task_params_dict.get('task_id')}] CausVid task defaulting to steps: {ui_defaults['num_inference_steps']}")
-
- ui_defaults["guidance_scale"] = 1.0 # Still overridden
- ui_defaults["flow_shift"] = 1.0 # Still overridden
-
- causvid_lora_basename = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
- current_activated = ui_defaults.get("activated_loras", [])
- if not isinstance(current_activated, list):
- try: current_activated = [str(item).strip() for item in str(current_activated).split(',') if item.strip()]
- except: current_activated = []
-
- if causvid_lora_basename not in current_activated:
- current_activated.append(causvid_lora_basename)
- ui_defaults["activated_loras"] = current_activated
-
- current_multipliers_str = ui_defaults.get("loras_multipliers", "")
- # Basic handling: if multipliers exist, prepend; otherwise, set directly.
- # More sophisticated merging might be needed if specific order or pairing is critical.
- # This assumes multipliers are space-separated.
- if current_multipliers_str:
- multipliers_list = current_multipliers_str.split()
- lora_names_list = [Path(lora_path).name for lora_path in all_loras_for_model
- if Path(lora_path).name in current_activated and Path(lora_path).name != causvid_lora_basename]
-
- final_multipliers = []
- final_loras = []
-
- # Add CausVid first
- final_loras.append(causvid_lora_basename)
- final_multipliers.append("0.7")
-
- # Add existing, ensuring no duplicate multiplier for already present CausVid (though it shouldn't be)
- processed_other_loras = set()
- for i, lora_name in enumerate(current_activated):
- if lora_name == causvid_lora_basename: continue # Already handled
- if lora_name not in processed_other_loras:
- final_loras.append(lora_name)
- if i < len(multipliers_list):
- final_multipliers.append(multipliers_list[i])
- else:
- final_multipliers.append("1.0") # Default if not enough multipliers
- processed_other_loras.add(lora_name)
-
- ui_defaults["activated_loras"] = final_loras # ensure order matches multipliers
- ui_defaults["loras_multipliers"] = " ".join(final_multipliers)
- else:
- ui_defaults["loras_multipliers"] = "0.7"
- ui_defaults["activated_loras"] = [causvid_lora_basename] # ensure only causvid if no others
-
- state[model_type_key] = ui_defaults
- return state, ui_defaults
+ Returns:
+ True if LoRA exists or was successfully downloaded, False otherwise
+ """
+ target_path = target_dir / lora_name
+
+ # Check model requirements if specified
+ if model_filename and model_requirements:
+ requirements_met = all(req in model_filename.lower() for req in model_requirements)
+ if not requirements_met:
+ print(f"[WARNING Task ID: {task_id}] {lora_name} is intended for models with {model_requirements}. Current model is {model_filename}. Results may vary.")
+
+ # Download if not present
+ if not target_path.exists():
+ print(f"[Task ID: {task_id}] {lora_name} not found. Attempting download...")
+ if not download_file(lora_url, target_dir, lora_name):
+ print(f"[WARNING Task ID: {task_id}] Failed to download {lora_name}. Proceeding without it.")
+ task_params_dict[flag_key] = False
+ return False
+
+ return True
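+
+# Example call (illustrative; the LoRA name/URL match the values used by the
+# inline CausVid download logic this helper replaces):
+#   _ensure_lora_downloaded(
+#       task_id,
+#       "Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
+#       "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors",
+#       Path(wgp_mod.get_lora_dir(model_filename_for_task)),
+#       task_params_dict,
+#       flag_key="use_causvid_lora",
+#       model_filename=model_filename_for_task,
+#       model_requirements=["14b", "t2v"],
+#   )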
# -----------------------------------------------------------------------------
-# 5. Download utility
+# 4. State builder for a single task (same as before)
# -----------------------------------------------------------------------------
-def download_file(url, dest_folder, filename):
- dest_path = Path(dest_folder) / filename
- if dest_path.exists():
- print(f"[INFO] File {filename} already exists in {dest_folder}.")
- return True
- try:
- print(f"Downloading {filename} from {url} to {dest_folder}...")
- response = requests.get(url, stream=True)
- response.raise_for_status() # Raise an exception for HTTP errors
- dest_path.parent.mkdir(parents=True, exist_ok=True)
- with open(dest_path, 'wb') as f:
- for chunk in response.iter_content(chunk_size=8192):
- f.write(chunk)
- print(f"Successfully downloaded {filename}.")
- return True
- except Exception as e:
- print(f"[ERROR] Failed to download {filename}: {e}")
- if dest_path.exists(): # Attempt to clean up partial download
- try: os.remove(dest_path)
- except: pass
- return False
+# --- SM_RESTRUCTURE: This function has been moved to source/common_utils.py ---
+
# -----------------------------------------------------------------------------
# 6. Process a single task dictionary from the tasks.json list
# -----------------------------------------------------------------------------
-def process_single_task(wgp_mod, task_params_dict, main_output_dir_base: Path, task_type: str):
- dprint(f"PROCESS_SINGLE_TASK received task_params_dict: {json.dumps(task_params_dict)}") # DEBUG ADDED
+def process_single_task(wgp_mod, task_params_dict, main_output_dir_base: Path, task_type: str, project_id_for_task: str | None, image_download_dir: Path | str | None = None, apply_reward_lora: bool = False, colour_match_videos: bool = False, mask_active_frames: bool = True):
+ dprint(f"--- Entering process_single_task ---")
+ dprint(f"Task Type: {task_type}")
+ dprint(f"Project ID for task: {project_id_for_task}") # Added dprint for project_id
+ dprint(f"Task Params (first 1000 chars): {json.dumps(task_params_dict, default=str, indent=2)[:1000]}...")
task_id = task_params_dict.get("task_id", "unknown_task_" + str(time.time()))
print(f"--- Processing task ID: {task_id} of type: {task_type} ---")
output_location_to_db = None # Will store the final path/URL for the DB
+ generation_success = False
- # --- Check for new task type ---
- if task_type == "generate_openpose":
- if PoseBodyFaceVideoAnnotator is None:
- print(f"[ERROR Task ID: {task_id}] PoseBodyFaceVideoAnnotator not imported. Cannot process 'generate_openpose' task.")
- return False, "PoseBodyFaceVideoAnnotator module not available."
-
- print(f"[Task ID: {task_id}] Identified as 'generate_openpose' task.")
- return _handle_generate_openpose_task(task_params_dict, main_output_dir_base, task_id)
- elif task_type == "rife_interpolate_images":
- print(f"[Task ID: {task_id}] Identified as 'rife_interpolate_images' task.")
- return _handle_rife_interpolate_task(wgp_mod, task_params_dict, main_output_dir_base, task_id)
- elif task_type == "comfyui_workflow": # New task type for ComfyUI
- print(f"[Task ID: {task_id}] Identified as 'comfyui_workflow' task.")
- return _handle_comfyui_workflow_task(task_params_dict, main_output_dir_base, task_id)
- # --- SM_RESTRUCTURE: Add new travel task handlers ---
- elif task_type == "travel_orchestrator":
+ # --- Orchestrator & Self-Contained Task Handlers ---
+ # These tasks manage their own sub-task queuing and can return directly, as they
+ # are either the start of a chain or a self-contained unit.
+ if task_type == "travel_orchestrator":
print(f"[Task ID: {task_id}] Identified as 'travel_orchestrator' task.")
- return _handle_travel_orchestrator_task(task_params_dict, main_output_dir_base, task_id)
+ # Ensure the orchestrator uses the DB row ID as its canonical task_id
+ task_params_dict["task_id"] = task_id
+ if "orchestrator_details" in task_params_dict:
+ task_params_dict["orchestrator_details"]["orchestrator_task_id"] = task_id
+ return tbi._handle_travel_orchestrator_task(task_params_from_db=task_params_dict, main_output_dir_base=main_output_dir_base, orchestrator_task_id_str=task_id, orchestrator_project_id=project_id_for_task, dprint=dprint)
elif task_type == "travel_segment":
print(f"[Task ID: {task_id}] Identified as 'travel_segment' task.")
- # This will call wgp_mod like a standard task but might have pre/post processing
- # based on orchestrator details passed in its params.
- return _handle_travel_segment_task(wgp_mod, task_params_dict, main_output_dir_base, task_id)
+ return tbi._handle_travel_segment_task(wgp_mod, task_params_dict, main_output_dir_base, task_id, apply_reward_lora, colour_match_videos, mask_active_frames, process_single_task=process_single_task, dprint=dprint)
elif task_type == "travel_stitch":
print(f"[Task ID: {task_id}] Identified as 'travel_stitch' task.")
- return _handle_travel_stitch_task(task_params_dict, main_output_dir_base, task_id)
- # --- End SM_RESTRUCTURE ---
-
- # Default handling for standard wgp tasks (original logic)
- task_model_type_logical = task_params_dict.get("model", "t2v")
- # Determine the actual model filename before checking/downloading LoRA, as LoRA path depends on it
- model_filename_for_task = wgp_mod.get_model_filename(task_model_type_logical,
- wgp_mod.transformer_quantization,
- wgp_mod.transformer_dtype_policy)
-
- use_causvid = task_params_dict.get("use_causvid_lora", False)
- causvid_lora_basename = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
- causvid_lora_url = "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-
- if use_causvid:
- base_lora_dir_for_model = Path(wgp_mod.get_lora_dir(model_filename_for_task))
- target_causvid_lora_dir = base_lora_dir_for_model
-
- if "14B" in model_filename_for_task and "t2v" in model_filename_for_task.lower():
- pass
- elif "14B" in model_filename_for_task:
- pass
-
- if not Path(target_causvid_lora_dir / causvid_lora_basename).exists():
- print(f"[Task ID: {task_id}] CausVid LoRA not found. Attempting download...")
- if not download_file(causvid_lora_url, target_causvid_lora_dir, causvid_lora_basename):
- print(f"[WARNING Task ID: {task_id}] Failed to download CausVid LoRA. Proceeding without it or with default settings.")
- task_params_dict["use_causvid_lora"] = False
- else:
- pass
- if not "14B" in model_filename_for_task or not "t2v" in model_filename_for_task.lower():
- print(f"[WARNING Task ID: {task_id}] CausVid LoRA is intended for 14B T2V models. Current model is {model_filename_for_task}. Results may vary.")
-
- print(f"[Task ID: {task_id}] Using model file: {model_filename_for_task}")
-
- # Use a temporary directory for wgp.py to save its output
- # This temporary directory will be created under the system's temp location
- temp_output_dir = tempfile.mkdtemp(prefix=f"wgp_headless_{task_id}_")
- dprint(f"[Task ID: {task_id}] Using temporary output directory: {temp_output_dir}")
-
- original_wgp_save_path = wgp_mod.save_path
- wgp_mod.save_path = str(temp_output_dir) # wgp.py saves here
-
- lora_dir_for_active_model = wgp_mod.get_lora_dir(model_filename_for_task)
- all_loras_for_active_model, _, _, _, _, _, _ = wgp_mod.setup_loras(
- model_filename_for_task, None, lora_dir_for_active_model, "", None
- )
+ return tbi._handle_travel_stitch_task(task_params_from_db=task_params_dict, main_output_dir_base=main_output_dir_base, stitch_task_id_str=task_id, dprint=dprint)
+ elif task_type == "different_perspective_orchestrator":
+ print(f"[Task ID: {task_id}] Identified as 'different_perspective_orchestrator' task.")
+ return dp._handle_different_perspective_orchestrator_task(
+ task_params_from_db=task_params_dict,
+ main_output_dir_base=main_output_dir_base,
+ orchestrator_task_id_str=task_id,
+ dprint=dprint
+ )
+ elif task_type == "dp_final_gen":
+ print(f"[Task ID: {task_id}] Identified as 'dp_final_gen' task.")
+ return dp._handle_dp_final_gen_task(
+ wgp_mod=wgp_mod,
+ main_output_dir_base=main_output_dir_base,
+ process_single_task=process_single_task,
+ task_params_from_db=task_params_dict,
+ dprint=dprint
+ )
+ elif task_type == "single_image":
+ print(f"[Task ID: {task_id}] Identified as 'single_image' task.")
+ return si._handle_single_image_task(
+ wgp_mod=wgp_mod,
+ task_params_from_db=task_params_dict,
+ main_output_dir_base=main_output_dir_base,
+ task_id=task_id,
+ image_download_dir=image_download_dir,
+ apply_reward_lora=apply_reward_lora,
+ dprint=dprint
+ )
+ elif task_type == "magic_edit":
+ print(f"[Task ID: {task_id}] Identified as 'magic_edit' task.")
+ return me._handle_magic_edit_task(
+ task_params_from_db=task_params_dict,
+ main_output_dir_base=main_output_dir_base,
+ task_id=task_id,
+ dprint=dprint
+ )
- state, ui_params = build_task_state(wgp_mod, model_filename_for_task, task_params_dict, all_loras_for_active_model)
-
- gen_task_placeholder = {"id": 1, "prompt": ui_params.get("prompt"), "params": {}}
- send_cmd = make_send_cmd(task_id)
+ # --- Primitive Task Execution Block ---
+ # These tasks (openpose, extract_frame, rife, and standard wgp generations) may be part of a chain.
+ # They set generation_success and output_location_to_db, then execution
+ # falls through to the chaining logic at the end of this function.
+ if task_type == "generate_openpose":
+ print(f"[Task ID: {task_id}] Identified as 'generate_openpose' task.")
+ generation_success, output_location_to_db = handle_generate_openpose_task(task_params_dict, main_output_dir_base, task_id, dprint)
- tea_cache_value = ui_params.get("tea_cache_setting", ui_params.get("tea_cache", 0.0))
+ elif task_type == "extract_frame":
+ print(f"[Task ID: {task_id}] Identified as 'extract_frame' task.")
+ generation_success, output_location_to_db = handle_extract_frame_task(task_params_dict, main_output_dir_base, task_id, dprint)
- print(f"[Task ID: {task_id}] Starting generation with effective params: {json.dumps(ui_params, default=lambda o: 'Unserializable' if isinstance(o, Image.Image) else o.__dict__ if hasattr(o, '__dict__') else str(o), indent=2)}")
- generation_success = False
+ elif task_type == "rife_interpolate_images":
+ print(f"[Task ID: {task_id}] Identified as 'rife_interpolate_images' task.")
+ generation_success, output_location_to_db = handle_rife_interpolate_task(wgp_mod, task_params_dict, main_output_dir_base, task_id, dprint)
- # --- Determine frame_num for wgp.py ---
- requested_frames_from_task = ui_params.get("video_length", 81)
- if requested_frames_from_task == 1:
- frame_num_for_wgp = 1
- dprint(f"[Task ID: {task_id}] Single-frame request detected, passing through (1 frame).")
- elif requested_frames_from_task <= 4:
- frame_num_for_wgp = 5 # wgp.py drops to 1 otherwise; use 5 to preserve motion guidance
- dprint(f"[Task ID: {task_id}] Small-frame ({requested_frames_from_task}) request adjusted to 5 frames for wgp.py compatibility.")
+ # Default handling for standard wgp tasks (original logic)
else:
- frame_num_for_wgp = (requested_frames_from_task // 4) * 4 + 1
- dprint(f"[Task ID: {task_id}] Calculated frame count for wgp.py: {frame_num_for_wgp} (requested: {requested_frames_from_task})")
- # --- End frame_num determination ---
+ task_model_type_logical = task_params_dict.get("model", "t2v")
+ model_filename_for_task = wgp_mod.get_model_filename(task_model_type_logical,
+ wgp_mod.transformer_quantization,
+ wgp_mod.transformer_dtype_policy)
+ custom_output_dir = task_params_dict.get("output_dir")
+
+ if apply_reward_lora:
+ print(f"[Task ID: {task_id}] --apply-reward-lora flag is active. Checking and downloading reward LoRA.")
+ # Use the shared download helper; a dummy params dict/key is passed, so a failed download cannot disable any task flag for the reward LoRA.
+ reward_url = "https://huggingface.co/peteromallet/Wan2.1-Fun-14B-InP-MPS_reward_lora_diffusers/resolve/main/Wan2.1-Fun-14B-InP-MPS_reward_lora_wgp.safetensors"
+ _ensure_lora_downloaded(task_id, "Wan2.1-Fun-14B-InP-MPS_reward_lora_wgp.safetensors",
+ reward_url, Path(wgp_mod.get_lora_dir(model_filename_for_task)), {}, "dummy_key")
- ui_params["video_length"] = frame_num_for_wgp
+ effective_image_download_dir = image_download_dir
- try:
- wgp_mod.generate_video(
- task=gen_task_placeholder, send_cmd=send_cmd,
- prompt=ui_params["prompt"],
- negative_prompt=ui_params.get("negative_prompt", ""),
- resolution=ui_params["resolution"],
- video_length=ui_params.get("video_length", 81),
- seed=ui_params["seed"],
- num_inference_steps=ui_params.get("num_inference_steps", 30),
- guidance_scale=ui_params.get("guidance_scale", 5.0),
- audio_guidance_scale=ui_params.get("audio_guidance_scale", 5.0),
- flow_shift=ui_params.get("flow_shift", wgp_mod.get_default_flow(model_filename_for_task, wgp_mod.test_class_i2v(model_filename_for_task))),
- embedded_guidance_scale=ui_params.get("embedded_guidance_scale", 6.0),
- repeat_generation=ui_params.get("repeat_generation", 1),
- multi_images_gen_type=ui_params.get("multi_images_gen_type", 0),
- tea_cache_setting=tea_cache_value,
- tea_cache_start_step_perc=ui_params.get("tea_cache_start_step_perc", 0),
- activated_loras=ui_params.get("activated_loras", []),
- loras_multipliers=ui_params.get("loras_multipliers", ""),
- image_prompt_type=ui_params.get("image_prompt_type", "T"),
- image_start=[wgp_mod.convert_image(img) for img in ui_params.get("image_start", [])],
- image_end=[wgp_mod.convert_image(img) for img in ui_params.get("image_end", [])],
- model_mode=ui_params.get("model_mode", 0),
- video_source=ui_params.get("video_source", None),
- keep_frames_video_source=ui_params.get("keep_frames_video_source", ""),
-
- video_prompt_type=ui_params.get("video_prompt_type", ""),
- image_refs=ui_params.get("image_refs", None),
- video_guide=ui_params.get("video_guide", None),
- keep_frames_video_guide=ui_params.get("keep_frames_video_guide", ""),
- video_mask=ui_params.get("video_mask", None),
- audio_guide=ui_params.get("audio_guide", None),
- sliding_window_size=ui_params.get("sliding_window_size", 81),
- sliding_window_overlap=ui_params.get("sliding_window_overlap", 5),
- sliding_window_overlap_noise=ui_params.get("sliding_window_overlap_noise", 20),
- sliding_window_discard_last_frames=ui_params.get("sliding_window_discard_last_frames", 0),
- remove_background_images_ref=ui_params.get("remove_background_images_ref", ui_params.get("remove_background_image_ref", 1)),
- temporal_upsampling=ui_params.get("temporal_upsampling", ""),
- spatial_upsampling=ui_params.get("spatial_upsampling", ""),
- RIFLEx_setting=ui_params.get("RIFLEx_setting", 0),
- slg_switch=ui_params.get("slg_switch", 0),
- slg_layers=ui_params.get("slg_layers", [9]),
- slg_start_perc=ui_params.get("slg_start_perc", 10),
- slg_end_perc=ui_params.get("slg_end_perc", 90),
- cfg_star_switch=ui_params.get("cfg_star_switch", 0),
- cfg_zero_step=ui_params.get("cfg_zero_step", -1),
- prompt_enhancer=ui_params.get("prompt_enhancer", ""),
- state=state,
- model_filename=model_filename_for_task
- )
- print(f"[Task ID: {task_id}] Generation completed to temporary directory: {wgp_mod.save_path}")
- generation_success = True
- except Exception as e:
- print(f"[ERROR] Task ID {task_id} failed during generation: {e}")
- traceback.print_exc()
- # generation_success remains False
- finally:
- wgp_mod.save_path = original_wgp_save_path # Restore original save path
+ if effective_image_download_dir is None:
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH:
+ try:
+ sqlite_db_path_obj = Path(db_ops.SQLITE_DB_PATH).resolve()
+ if sqlite_db_path_obj.is_file():
+ sqlite_db_parent_dir = sqlite_db_path_obj.parent
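+ # Image downloads are staged next to the SQLite DB, under public/data/image_downloads/<task_id>.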
+ candidate_download_dir = sqlite_db_parent_dir / "public" / "data" / "image_downloads" / task_id
+ candidate_download_dir.mkdir(parents=True, exist_ok=True)
+ effective_image_download_dir = str(candidate_download_dir.resolve())
+ dprint(f"Task {task_id}: Determined SQLite-based image_download_dir for standard task: {effective_image_download_dir}")
+ else:
+ dprint(f"Task {task_id}: SQLITE_DB_PATH '{db_ops.SQLITE_DB_PATH}' is not a file. Cannot determine parent for image_download_dir.")
+ except Exception as e_idir_sqlite:
+ dprint(f"Task {task_id}: Could not create SQLite-based image_download_dir for standard task: {e_idir_sqlite}.")
- if generation_success:
- # Find the generated video file (assuming one .mp4)
- generated_video_file = None
- for item in Path(temp_output_dir).iterdir():
- if item.is_file() and item.suffix.lower() == ".mp4":
- generated_video_file = item
- break
+ # Handle special LoRA downloads using shared helper
+ base_lora_dir_for_model = Path(wgp_mod.get_lora_dir(model_filename_for_task))
- if generated_video_file:
- dprint(f"[Task ID: {task_id}] Found generated video: {generated_video_file}")
- if DB_TYPE == "sqlite":
- dprint(f"HEADLESS SQLITE SAVE: task_params_dict contains output_path? Key: 'output_path', Value: {task_params_dict.get('output_path')}") # DEBUG ADDED
- custom_output_path_str = task_params_dict.get("output_path")
- if custom_output_path_str:
- # `output_path` is expected to be a full path (absolute or relative). If relative, resolve against cwd.
- final_video_path = Path(custom_output_path_str).expanduser().resolve()
- final_video_path.parent.mkdir(parents=True, exist_ok=True)
- else:
- output_sub_dir_val = task_params_dict.get("output_sub_dir", task_id)
- output_sub_dir_path = Path(output_sub_dir_val)
- # If the provided sub-dir is absolute, use it directly; otherwise nest it under the main output dir.
- if output_sub_dir_path.is_absolute():
- final_output_dir = output_sub_dir_path
- else:
- final_output_dir = main_output_dir_base / output_sub_dir_path
- final_output_dir.mkdir(parents=True, exist_ok=True)
- final_video_path = final_output_dir / f"{task_id}.mp4"
+ if task_params_dict.get("use_causvid_lora", False):
+ _ensure_lora_downloaded(
+ task_id, "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+ "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+ base_lora_dir_for_model, task_params_dict, "use_causvid_lora",
+ model_filename_for_task, ["14b", "t2v"]
+ )
+
+ if task_params_dict.get("use_lighti2x_lora", False):
+ _ensure_lora_downloaded(
+ task_id, "wan_lcm_r16_fp32_comfy.safetensors",
+ "https://huggingface.co/peteromallet/ad_motion_loras/resolve/main/wan_lcm_r16_fp32_comfy.safetensors",
+ base_lora_dir_for_model, task_params_dict, "use_lighti2x_lora"
+ )
+
+ additional_loras = task_params_dict.get("additional_loras", {})
+ if additional_loras:
+ dprint(f"[Task ID: {task_id}] Processing additional LoRAs: {additional_loras}")
+ processed_loras = {}
+ for lora_path, lora_strength in additional_loras.items():
try:
- shutil.move(str(generated_video_file), str(final_video_path))
- output_location_to_db = str(final_video_path.resolve())
- print(f"[Task ID: {task_id}] Output video saved to: {output_location_to_db}")
-
- # If a custom output_path was used, there might still be a default
- # directory (e.g., /) that was created
- # earlier by wgp.py or previous logic. Clean it up to avoid clutter.
- if custom_output_path_str:
- default_dir_to_clean = (main_output_dir_base / task_id)
- try:
- if default_dir_to_clean.exists() and default_dir_to_clean.is_dir():
- shutil.rmtree(default_dir_to_clean)
- dprint(f"[Task ID: {task_id}] Removed default output directory that is no longer needed: {default_dir_to_clean}")
- except Exception as e_cleanup_default:
- print(f"[WARNING Task ID: {task_id}] Could not remove default output directory {default_dir_to_clean}: {e_cleanup_default}")
- except Exception as e_move:
- print(f"[ERROR Task ID: {task_id}] Failed to move video to final local destination: {e_move}")
- generation_success = False # Mark as failed if file handling fails
+ lora_filename = Path(urllib.parse.urlparse(lora_path).path).name
+
+ if lora_path.startswith("http://") or lora_path.startswith("https://"):
+ wan2gp_loras_dir = Path("Wan2GP/loras")
+ wan2gp_loras_dir.mkdir(parents=True, exist_ok=True)
+ dprint(f"Task {task_id}: Downloading LoRA from {lora_path} to {wan2gp_loras_dir}")
+ download_file(lora_path, wan2gp_loras_dir, lora_filename)
+
+ downloaded_lora_path = wan2gp_loras_dir / lora_filename
+ target_path = base_lora_dir_for_model / lora_filename
+ if not target_path.exists() or downloaded_lora_path.resolve() != target_path.resolve():
+ shutil.copy(str(downloaded_lora_path), str(target_path))
+ dprint(f"Copied downloaded LoRA from {downloaded_lora_path} to {target_path}")
+ else:
+ dprint(f"Downloaded LoRA already exists at target: {target_path}")
+ else:
+ source_path = Path(lora_path)
+ target_path = base_lora_dir_for_model / lora_filename
+
+ if source_path.is_absolute() and source_path.exists():
+ if not target_path.exists() or source_path.resolve() != target_path.resolve():
+ shutil.copy(str(source_path), str(target_path))
+ dprint(f"Copied local LoRA from {source_path} to {target_path}")
+ else:
+ dprint(f"LoRA already exists at target: {target_path}")
+ elif (base_lora_dir_for_model / lora_path).exists():
+ existing_lora_path = base_lora_dir_for_model / lora_path
+ lora_filename = existing_lora_path.name
+ dprint(f"Found existing LoRA in directory: {existing_lora_path}")
+ elif (Path("Wan2GP/loras") / lora_path).exists():
+ wan2gp_lora_path = Path("Wan2GP/loras") / lora_path
+ target_path = base_lora_dir_for_model / wan2gp_lora_path.name
+ if not target_path.exists() or wan2gp_lora_path.resolve() != target_path.resolve():
+ shutil.copy(str(wan2gp_lora_path), str(target_path))
+ dprint(f"Copied LoRA from Wan2GP/loras: {wan2gp_lora_path} to {target_path}")
+ else:
+ dprint(f"LoRA from Wan2GP/loras already exists at target: {target_path}")
+ lora_filename = wan2gp_lora_path.name
+ elif source_path.exists():
+ if not target_path.exists() or source_path.resolve() != target_path.resolve():
+ shutil.copy(str(source_path), str(target_path))
+ dprint(f"Copied local LoRA from {source_path} to {target_path}")
+ else:
+ dprint(f"LoRA already exists at target: {target_path}")
+ else:
+ dprint(f"[WARNING Task ID: {task_id}] LoRA not found at any location: '{lora_path}'. Checked: full path, LoRA directory, and relative path. Skipping.")
+ continue
+
+ processed_loras[lora_filename] = str(lora_strength)
+ except Exception as e_lora:
+ print(f"[ERROR Task ID: {task_id}] Failed to process additional LoRA {lora_path}: {e_lora}")
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
- # For Supabase, use task_id as part of the object name to ensure uniqueness
- # You might want to include the original filename if it's meaningful
-
- # URL-encode the filename part to handle spaces and special characters
- encoded_file_name = urllib.parse.quote(generated_video_file.name)
- object_name = f"{task_id}/{encoded_file_name}"
- # Or if you wanted task_id as the filename: object_name = f"{task_id}.mp4"
- # Or if you wanted just the encoded filename: object_name = encoded_file_name
- dprint(f"[Task ID: {task_id}] Original filename: {generated_video_file.name}")
- dprint(f"[Task ID: {task_id}] Encoded filename for Supabase object: {encoded_file_name}")
- dprint(f"[Task ID: {task_id}] Final Supabase object_name: {object_name}")
-
- public_url = upload_to_supabase_storage(generated_video_file, object_name, SUPABASE_VIDEO_BUCKET)
- if public_url:
- output_location_to_db = public_url
- else:
- print(f"[WARNING Task ID: {task_id}] Supabase upload failed or no URL returned. No output location will be saved.")
- generation_success = False # Mark as failed if upload fails
- else:
- print(f"[WARNING Task ID: {task_id}] Output generated but DB_TYPE ({DB_TYPE}) is not sqlite or Supabase client is not configured for upload.")
- generation_success = False # Cannot determine final resting place
- else:
- print(f"[WARNING Task ID: {task_id}] Generation reported success, but no .mp4 file found in {temp_output_dir}")
- generation_success = False # No output to save
-
- # Clean up the temporary directory
- try:
- shutil.rmtree(temp_output_dir)
- dprint(f"[Task ID: {task_id}] Cleaned up temporary directory: {temp_output_dir}")
- except Exception as e_clean:
- print(f"[WARNING Task ID: {task_id}] Failed to clean up temporary directory {temp_output_dir}: {e_clean}")
+ task_params_dict["processed_additional_loras"] = processed_loras
- print(f"--- Finished task ID: {task_id} (Success: {generation_success}) ---")
- return generation_success, output_location_to_db
+ print(f"[Task ID: {task_id}] Using model file: {model_filename_for_task}")
-# -----------------------------------------------------------------------------
-# +++. New function to handle 'generate_openpose' task
-# -----------------------------------------------------------------------------
-def _handle_generate_openpose_task(task_params_dict: dict, main_output_dir_base: Path, task_id: str):
- """
- Handles the 'generate_openpose' task type.
- Generates an OpenPose image from an input image path specified in task_params_dict.
- Saves the OpenPose image to the output_path also specified in task_params_dict.
- """
- print(f"[Task ID: {task_id}] Handling 'generate_openpose' task.")
- input_image_path_str = task_params_dict.get("input_image_path")
- output_image_path_str = task_params_dict.get("output_path") # Expecting full path from caller
-
- if not input_image_path_str:
- print(f"[ERROR Task ID: {task_id}] 'input_image_path' not specified for generate_openpose task.")
- return False, "Missing input_image_path"
-
- if not output_image_path_str:
- # Fallback if not specified, though steerable_motion.py should always provide it
- default_output_dir = main_output_dir_base / task_id
- default_output_dir.mkdir(parents=True, exist_ok=True)
- output_image_path = default_output_dir / f"{task_id}_openpose.png"
- print(f"[WARNING Task ID: {task_id}] 'output_path' not specified. Defaulting to {output_image_path}")
- else:
- output_image_path = Path(output_image_path_str)
+ temp_output_dir = tempfile.mkdtemp(prefix=f"wgp_headless_{task_id}_")
+ dprint(f"[Task ID: {task_id}] Using temporary output directory: {temp_output_dir}")
- input_image_path = Path(input_image_path_str)
- output_image_path.parent.mkdir(parents=True, exist_ok=True)
+ original_wgp_save_path = wgp_mod.save_path
+ wgp_mod.save_path = str(temp_output_dir)
- if not input_image_path.is_file():
- print(f"[ERROR Task ID: {task_id}] Input image file not found: {input_image_path}")
- return False, f"Input image not found: {input_image_path}"
+ lora_dir_for_active_model = wgp_mod.get_lora_dir(model_filename_for_task)
+ all_loras_for_active_model, _, _, _, _, _, _ = wgp_mod.setup_loras(
+ model_filename_for_task, None, lora_dir_for_active_model, "", None
+ )
- try:
- pil_input_image = Image.open(input_image_path).convert("RGB")
-
- # Config for Pose Annotator (similar to wgp.py)
- # Ensure these ckpt paths are accessible relative to where headless.py runs
- pose_cfg_dict = {
- "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
- "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
- "RESIZE_SIZE": 1024 # Internal resize for detection, not necessarily final output size
- }
- if PoseBodyFaceVideoAnnotator is None: # Should have been caught earlier, but double check
- raise ImportError("PoseBodyFaceVideoAnnotator could not be imported.")
-
- pose_annotator = PoseBodyFaceVideoAnnotator(pose_cfg_dict)
+ state, ui_params = build_task_state(wgp_mod, model_filename_for_task, task_params_dict, all_loras_for_active_model, image_download_dir, apply_reward_lora=apply_reward_lora)
- # The forward method expects a list of PIL images
- # It returns a list of NumPy arrays (H, W, C) in BGR format by default from dwpose
- openpose_np_frames_bgr = pose_annotator.forward([pil_input_image])
+ gen_task_placeholder = {"id": 1, "prompt": ui_params.get("prompt"), "params": {"model_filename_from_gui_state": model_filename_for_task, "model": task_model_type_logical}}
+ send_cmd = make_send_cmd(task_id)
- if not openpose_np_frames_bgr or openpose_np_frames_bgr[0] is None:
- print(f"[ERROR Task ID: {task_id}] OpenPose generation failed or returned no frame.")
- return False, "OpenPose generation returned no data."
-
- openpose_np_frame_bgr = openpose_np_frames_bgr[0]
-
- # PoseBodyFaceVideoAnnotator output is BGR, convert to RGB for PIL save if needed
- # However, PIL Image.fromarray can often handle BGR if mode is not specified,
- # but explicitly converting is safer for PNG.
- # Or, let's check if the annotator itself returns RGB.
- # dwpose.annotator.py > draw_pose seems to draw on an RGB copy.
- # Let's assume `anno_ins.forward` gives RGB-compatible array or directly RGB.
- # If colors are inverted, we'll need: openpose_np_frame_rgb = cv2.cvtColor(openpose_np_frame_bgr, cv2.COLOR_BGR2RGB)
- # For now, let's assume it's directly usable by PIL.
-
- openpose_pil_image = Image.fromarray(openpose_np_frame_bgr.astype(np.uint8)) # Ensure uint8
- openpose_pil_image.save(output_image_path)
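+ # Snap both dimensions down to the nearest multiple of 16, e.g. a requested "902x508" becomes "896x496".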
+ if "resolution" in ui_params:
+ try:
+ width, height = map(int, ui_params["resolution"].split("x"))
+ new_width, new_height = (width // 16) * 16, (height // 16) * 16
+ ui_params["resolution"] = f"{new_width}x{new_height}"
+ dprint(f"Adjusted resolution in ui_params to {ui_params['resolution']}")
+ except Exception as err:
+ dprint(f"Error adjusting resolution: {err}")
- print(f"[Task ID: {task_id}] Successfully generated OpenPose image to: {output_image_path.resolve()}")
- return True, str(output_image_path.resolve())
+ tea_cache_value = ui_params.get("tea_cache_setting", ui_params.get("tea_cache", 0.0))
- except ImportError as ie:
- print(f"[ERROR Task ID: {task_id}] Import error during OpenPose generation: {ie}. Ensure 'preprocessing' module is in PYTHONPATH and dependencies are installed.")
- traceback.print_exc()
- return False, f"Import error: {ie}"
- except FileNotFoundError as fnfe: # For missing ONNX models
- print(f"[ERROR Task ID: {task_id}] ONNX model file not found for OpenPose: {fnfe}. Ensure 'ckpts/pose/*' models are present.")
- traceback.print_exc()
- return False, f"ONNX model not found: {fnfe}"
- except Exception as e:
- print(f"[ERROR Task ID: {task_id}] Failed during OpenPose image generation: {e}")
- traceback.print_exc()
- return False, f"OpenPose generation exception: {e}"
+ print(f"[Task ID: {task_id}] Starting generation with effective params: {json.dumps(ui_params, default=lambda o: 'Unserializable' if isinstance(o, Image.Image) else o.__dict__ if hasattr(o, '__dict__') else str(o), indent=2)}")
-# -----------------------------------------------------------------------------
-# +++. New function to handle 'rife_interpolate_images' task
-# -----------------------------------------------------------------------------
-def _handle_rife_interpolate_task(wgp_mod, task_params_dict: dict, main_output_dir_base: Path, task_id: str):
- """
- Handles the 'rife_interpolate_images' task type using wgp.py's RIFE capabilities.
- """
- print(f"[Task ID: {task_id}] Handling 'rife_interpolate_images' task.")
-
- input_image_path1_str = task_params_dict.get("input_image_path1")
- input_image_path2_str = task_params_dict.get("input_image_path2")
- output_video_path_str = task_params_dict.get("output_path")
- num_rife_frames = task_params_dict.get("frames")
- resolution_str = task_params_dict.get("resolution") # e.g., "960x544"
- # Default model for RIFE context, actual RIFE process might be model-agnostic in wgp.py
- # Using a known model type like t2v for get_model_filename
- default_model_for_context = task_params_dict.get("model_context_for_rife", "vace_14B")
-
- # Validate required parameters
- required_params = {
- "input_image_path1": input_image_path1_str,
- "input_image_path2": input_image_path2_str,
- "output_path": output_video_path_str,
- "frames": num_rife_frames,
- "resolution": resolution_str
- }
- missing_params = [key for key, value in required_params.items() if value is None]
- if missing_params:
- error_msg = f"Missing required parameters for rife_interpolate_images: {', '.join(missing_params)}"
- print(f"[ERROR Task ID: {task_id}] {error_msg}")
- return False, error_msg
-
- input_image1_path = Path(input_image_path1_str)
- input_image2_path = Path(input_image_path2_str)
- output_video_path = Path(output_video_path_str)
- output_video_path.parent.mkdir(parents=True, exist_ok=True)
-
- generation_success = False # Initialize
- output_location_to_db = None # Initialize
-
- dprint(f"[Task ID: {task_id}] Checking input image paths.")
- if not input_image1_path.is_file():
- print(f"[ERROR Task ID: {task_id}] Input image 1 not found: {input_image1_path}")
- return False, f"Input image 1 not found: {input_image1_path}"
- if not input_image2_path.is_file():
- print(f"[ERROR Task ID: {task_id}] Input image 2 not found: {input_image2_path}")
- return False, f"Input image 2 not found: {input_image2_path}"
- dprint(f"[Task ID: {task_id}] Input images found.")
-
- # Create a temporary directory for wgp.py output (though not used by direct RIFE, kept for structure)
- temp_output_dir = tempfile.mkdtemp(prefix=f"wgp_rife_{task_id}_")
- original_wgp_save_path = wgp_mod.save_path
- wgp_mod.save_path = str(temp_output_dir)
-
- try:
- pil_image_start = Image.open(input_image1_path).convert("RGB")
- pil_image_end = Image.open(input_image2_path).convert("RGB")
-
- # Get a valid model_filename for wgp.py context (RIFE might not use its weights)
- # Assuming default transformer_quantization and transformer_dtype_policy from wgp_mod
- actual_model_filename = wgp_mod.get_model_filename(
- default_model_for_context, # e.g., "vace_14B"
- wgp_mod.transformer_quantization,
- wgp_mod.transformer_dtype_policy
- )
- dprint(f"[Task ID: {task_id}] Using model file for RIFE context: {actual_model_filename}")
+ requested_frames_from_task = ui_params.get("video_length", 81)
+ frame_num_for_wgp = requested_frames_from_task
+ print(f"[HEADLESS_DEBUG] Task {task_id}: FRAME COUNT ANALYSIS")
+ print(f"[HEADLESS_DEBUG] requested_frames_from_task: {requested_frames_from_task}")
+ print(f"[HEADLESS_DEBUG] frame_num_for_wgp: {frame_num_for_wgp}")
+ print(f"[HEADLESS_DEBUG] ui_params video_length: {ui_params.get('video_length')}")
+ dprint(f"[Task ID: {task_id}] Using requested frame count: {frame_num_for_wgp}")
- # Prepare minimal state for wgp.py
- # LoRA setup for the context model (RIFE likely won't use these LoRAs)
- lora_dir_for_context_model = wgp_mod.get_lora_dir(actual_model_filename)
- all_loras_for_context_model, _, _, _, _, _, _ = wgp_mod.setup_loras(
- actual_model_filename, None, lora_dir_for_context_model, "", None
- )
-
- # Basic state, RIFE might not need detailed model-specific ui_params
- # The key is that `generate_video` has a valid `model_filename` in `state`
- # and `state[wgp_mod.get_model_type(actual_model_filename)]` exists.
- model_type_key = wgp_mod.get_model_type(actual_model_filename)
- minimal_ui_defaults_for_model_type = wgp_mod.get_default_settings(actual_model_filename).copy()
- # Override crucial RIFE params in the minimal_ui_defaults for clarity if generate_video uses them from here
- minimal_ui_defaults_for_model_type["resolution"] = resolution_str
- minimal_ui_defaults_for_model_type["video_length"] = int(num_rife_frames)
-
-
- state = {
- "model_filename": actual_model_filename,
- "loras": all_loras_for_context_model, # Provide loras for the context model
- model_type_key: minimal_ui_defaults_for_model_type, # Default settings for the context model type
- "gen": {"queue": [], "file_list": [], "file_settings_list": [], "prompt_no": 1, "prompts_max": 1}
- }
-
- print(f"[Task ID: {task_id}] Starting direct RIFE interpolation (bypassing wgp.py).")
- dprint(f" Input 1: {input_image1_path}")
- dprint(f" Input 2: {input_image2_path}")
-
- # ---- Direct RIFE Implementation ----
- import torch
- import numpy as np
- from rife.inference import temporal_interpolation
- dprint(f"[Task ID: {task_id}] Imported RIFE modules.")
-
- width_out, height_out = map(int, resolution_str.split("x"))
- dprint(f"[Task ID: {task_id}] Parsed resolution: {width_out}x{height_out}")
-
- def pil_to_tensor_rgb_norm(pil_im: Image.Image):
- pil_resized = pil_im.resize((width_out, height_out), Image.Resampling.LANCZOS)
- np_rgb = np.asarray(pil_resized).astype(np.float32) / 127.5 - 1.0 # [0,255]->[-1,1]
- tensor = torch.from_numpy(np_rgb).permute(2, 0, 1) # C H W
- return tensor
-
- t_start = pil_to_tensor_rgb_norm(pil_image_start)
- t_end = pil_to_tensor_rgb_norm(pil_image_end)
-
- sample_in = torch.stack([t_start, t_end], dim=1).unsqueeze(0) # 1 x 3 x 2 x H x W
-
- device_for_rife = "cuda" if torch.cuda.is_available() else "cpu"
- sample_in = sample_in.to(device_for_rife)
- dprint(f"[Task ID: {task_id}] Input tensor for RIFE prepared on device: {device_for_rife}, shape: {sample_in.shape}")
-
- exp_val = 3 # x8 (2^3 + 1 = 9 frames output by this RIFE implementation for 2 inputs)
- flownet_ckpt = os.path.join("ckpts", "flownet.pkl")
- dprint(f"[Task ID: {task_id}] Checking for RIFE model: {flownet_ckpt}")
- if not os.path.exists(flownet_ckpt):
- error_msg_flownet = f"RIFE Error: flownet.pkl not found at {flownet_ckpt}"
- print(f"[ERROR Task ID: {task_id}] {error_msg_flownet}")
- # generation_success remains False, will be returned
- return False, error_msg_flownet # Explicitly return
- dprint(f"[Task ID: {task_id}] RIFE model found: {flownet_ckpt}. Exp_val: {exp_val}")
-
- # Remove batch dimension for rife.inference.temporal_interpolation's internal process_frames
- sample_in_for_rife = sample_in[0] # Shape: C x 2 x H x W
+ ui_params["video_length"] = frame_num_for_wgp
try:
- # sample_out_from_rife will have shape C x F_out x H x W
- dprint(f"[Task ID: {task_id}] Calling temporal_interpolation with input shape: {sample_in_for_rife.shape}")
- sample_out_from_rife = temporal_interpolation(flownet_ckpt, sample_in_for_rife, exp_val, device=device_for_rife)
- dprint(f"[Task ID: {task_id}] temporal_interpolation call completed.")
- if sample_out_from_rife is not None:
- dprint(f"[Task ID: {task_id}] RIFE output tensor shape: {sample_out_from_rife.shape}")
- else:
- dprint(f"[Task ID: {task_id}] RIFE process returned None for sample_out_from_rife.")
- except Exception as e_rife: # Inner catch for temporal_interpolation
- print(f"[ERROR Task ID: {task_id}] RIFE interpolation failed: {e_rife}")
- traceback.print_exc()
- generation_success = False
- sample_out_from_rife = None
-
- if sample_out_from_rife is not None:
- sample_out_no_batch = sample_out_from_rife.to("cpu") # Shape C x F_out x H x W
- total_frames_generated = sample_out_no_batch.shape[1] # F_out is at index 1
- dprint(f"[Task ID: {task_id}] RIFE produced {total_frames_generated} frames.")
-
- # Trim / pad to desired num_rife_frames
- if total_frames_generated < num_rife_frames:
- print(f"[Task ID: {task_id}] Warning: RIFE produced {total_frames_generated} frames, expected {num_rife_frames}. Padding last frame.")
- pad_frames = num_rife_frames - total_frames_generated
- else:
- pad_frames = 0
-
- frames_list_np = []
- # Iterate up to the smaller of desired frames or actual RIFE output frames
- for idx in range(min(num_rife_frames, total_frames_generated)):
- frame_tensor = sample_out_no_batch[:, idx] # Shape: C x H x W
- frame_np = ((frame_tensor.permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8) # H W C RGB
- frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
- frames_list_np.append(frame_bgr)
-
- # Pad with last frame if RIFE produced fewer frames than desired *and* we have at least one frame
- if pad_frames > 0 and frames_list_np:
- last_frame_to_pad = frames_list_np[-1].copy()
- frames_list_np.extend([last_frame_to_pad for _ in range(pad_frames)])
- elif not frames_list_np and num_rife_frames > 0: # RIFE produced 0 frames, or processing failed to create any
- dprint(f"[Task ID: {task_id}] Error: No frames available to write for RIFE video (num_rife_frames: {num_rife_frames}).")
- # generation_success remains False (from initialization)
-
- # Only proceed to write if we have frames in frames_list_np
- if frames_list_np:
- try:
- fps_output = 25 # Or use a parameter if available/needed
- dprint(f"[Task ID: {task_id}] Writing RIFE output video to: {output_video_path} with {len(frames_list_np)} frames.")
- output_video_path.parent.mkdir(parents=True, exist_ok=True)
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- out_writer = cv2.VideoWriter(str(output_video_path), fourcc, float(fps_output), (width_out, height_out))
- if not out_writer.isOpened():
- # This will be caught by the outer broad 'except Exception as e'
- raise IOError(f"Failed to open VideoWriter for output RIFE video: {output_video_path}")
- for frame_np in frames_list_np:
- out_writer.write(frame_np)
- out_writer.release()
-
- if output_video_path.exists() and output_video_path.stat().st_size > 0:
- dprint(f"[Task ID: {task_id}] Video file confirmed exists and has size > 0.")
- generation_success = True # Set to True only after successful write and check
- output_location_to_db = str(output_video_path.resolve())
- print(f"[Task ID: {task_id}] Direct RIFE video saved to: {output_location_to_db}")
- else:
- print(f"[ERROR Task ID: {task_id}] RIFE output file missing or empty after writing attempt: {output_video_path}")
- # generation_success remains False (from initialization)
- except Exception as e_video_write:
- print(f"[ERROR Task ID: {task_id}] Exception during RIFE video writing: {e_video_write}")
- traceback.print_exc()
- # generation_success remains False (from initialization)
- # If frames_list_np was empty, generation_success remains False
- else: # sample_out_from_rife was None (RIFE call itself failed and was caught by e_rife)
- # generation_success remains False (error already printed by e_rife block)
- pass # dprint for this case is already after the e_rife block
-
- # ---- End Direct RIFE Implementation ----
-
- except Exception as e: # Broad catch-all for the whole RIFE block
- print(f"[ERROR Task ID: {task_id}] Overall _handle_rife_interpolate_task failed: {e}")
- traceback.print_exc()
- generation_success = False
- finally:
- wgp_mod.save_path = original_wgp_save_path # Restore original save path
-
- return generation_success, output_location_to_db
-
-# -----------------------------------------------------------------------------
-# +++. New function to handle 'comfyui_workflow' task
-# -----------------------------------------------------------------------------
-def _handle_comfyui_workflow_task(task_params_dict: dict, main_output_dir_base: Path, task_id: str):
- """
- Handles the 'comfyui_workflow' task type.
- Executes a ComfyUI workflow based on parameters in task_params_dict.
- """
- print(f"[Task ID: {task_id}] Handling 'comfyui_workflow' task.")
- generation_success = False
- output_location_to_db = None
-
- # --- 1. Extract ComfyUI specific parameters from task_params_dict ---
- comfyui_workflow_file_path = task_params_dict.get("comfyui_workflow_path")
- comfyui_server_address = task_params_dict.get("comfyui_server_address", "http://127.0.0.1:8188") # Default if not provided
- comfyui_inputs_for_workflow = task_params_dict.get("comfyui_inputs", {})
- # Example: client_id for ComfyUI API
- comfyui_client_id = str(uuid.uuid4())
-
- dprint(f"[Task ID: {task_id}] ComfyUI Workflow File: {comfyui_workflow_file_path}")
- dprint(f"[Task ID: {task_id}] ComfyUI Server: {comfyui_server_address}")
- dprint(f"[Task ID: {task_id}] ComfyUI Inputs: {json.dumps(comfyui_inputs_for_workflow, indent=2)}")
-
- if not comfyui_workflow_file_path:
- print(f"[ERROR Task ID: {task_id}] 'comfyui_workflow_path' not specified.")
- return False, "Missing comfyui_workflow_path"
-
- workflow_path = Path(comfyui_workflow_file_path)
- if not workflow_path.is_file():
- print(f"[ERROR Task ID: {task_id}] ComfyUI workflow file not found: {workflow_path}")
- return False, f"ComfyUI workflow file not found: {workflow_path}"
-
- # Create a temporary directory for any intermediate files ComfyUI might need or for its direct output
- # before we move it to the final location.
- temp_comfy_output_dir = tempfile.mkdtemp(prefix=f"comfy_headless_{task_id}_")
- dprint(f"[Task ID: {task_id}] Using temporary directory for ComfyUI outputs/staging: {temp_comfy_output_dir}")
-
- try:
- # --- 2. Load the ComfyUI workflow JSON ---
- # Placeholder: You'll need to read the JSON file content here.
- # Example:
- # with open(workflow_path, 'r') as f:
- # loaded_workflow_json = json.load(f)
- loaded_workflow_json = {} # Replace with actual loading
- dprint(f"[Task ID: {task_id}] Successfully loaded workflow (placeholder): {workflow_path.name}")
-
- # --- 3. Modify the workflow JSON with comfyui_inputs_for_workflow ---
- # Placeholder: This is the core logic for dynamically setting node inputs.
- # You will need to iterate through `comfyui_inputs_for_workflow` and
- # find the corresponding nodes in `loaded_workflow_json` to update their values.
- # Nodes are often identified by title (in _meta.title) or by class_type and index.
- # Example (highly conceptual):
- # for node_id, node_data in loaded_workflow_json.items(): # ComfyUI API format often has nodes keyed by string ID
- # node_title = node_data.get("_meta", {}).get("title")
- # if node_title and node_title in comfyui_inputs_for_workflow:
- # # Find the correct widget/input field and update it
- # # For a 'LoadImage' node, this might be node_data['inputs']['image'] = comfyui_inputs_for_workflow[node_title]
- # # For a 'CLIPTextEncode' (prompt) node, node_data['inputs']['text'] = comfyui_inputs_for_workflow[node_title]
- # dprint(f"[Task ID: {task_id}] Updated node '{node_title}' in workflow.")
- # pass # Actual update logic here
- dprint(f"[Task ID: {task_id}] Workflow modification step (placeholder). You need to implement dynamic input injection.")
-
- # --- 4. Send the workflow to ComfyUI API (/prompt endpoint) ---
- # Placeholder: Use the `requests` library to POST the `loaded_workflow_json`.
- # Example:
- # prompt_payload = {"prompt": loaded_workflow_json, "client_id": comfyui_client_id}
- # response = requests.post(f"{comfyui_server_address}/prompt", json=prompt_payload)
- # response.raise_for_status()
- # prompt_response_data = response.json()
- # submitted_prompt_id = prompt_response_data.get('prompt_id')
- submitted_prompt_id = "dummy_prompt_id_123" # Replace with actual submission
- if not submitted_prompt_id:
- print(f"[ERROR Task ID: {task_id}] Failed to submit prompt to ComfyUI or did not receive prompt_id.")
- # dprint(f"ComfyUI submission response: {prompt_response_data}") # If you had actual response
- raise ValueError("ComfyUI prompt submission failed")
- dprint(f"[Task ID: {task_id}] Workflow submitted to ComfyUI. Prompt ID: {submitted_prompt_id}")
-
- # --- 5. Poll ComfyUI API for completion status (/history or /queue) ---
- # Placeholder: Loop and poll the /history/{submitted_prompt_id} endpoint.
- # Example:
- # while True:
- # history_response = requests.get(f"{comfyui_server_address}/history/{submitted_prompt_id}")
- # history_response.raise_for_status()
- # history_data = history_response.json()
- # if submitted_prompt_id in history_data and history_data[submitted_prompt_id].get("outputs"):
- # dprint(f"[Task ID: {task_id}] ComfyUI workflow execution completed.")
- # comfyui_outputs = history_data[submitted_prompt_id]["outputs"]
- # break
- # # Check for errors in queue or history too
- # # Add timeout logic
- # time.sleep(5) # Polling interval
- comfyui_outputs = {} # Replace with actual polling and output retrieval
- dprint(f"[Task ID: {task_id}] Workflow completion polling (placeholder). Outputs (placeholder): {comfyui_outputs}")
-
- # --- 6. Identify and retrieve output files ---
- # Placeholder: Parse `comfyui_outputs` to find the generated files.
- # ComfyUI outputs often specify filename, subfolder, type relative to ComfyUI's base output directory.
- # You'll need to know ComfyUI's output directory structure or configure it.
- # Example:
- # output_video_files = []
- # for node_id_output, node_output_data in comfyui_outputs.items():
- # if 'videos' in node_output_data:
- # for video_info in node_output_data['videos']:
- # # Construct full path: ComfyUI_output_dir / video_info['subfolder'] / video_info['filename']
- # # output_video_files.append(full_path_to_video)
- # pass
- # generated_final_file_path = output_video_files[0] if output_video_files else None
-
- # --- Temporary placeholder for generated_final_file_path using COMFYUI_OUTPUT_PATH_CONFIG ---
- # This simulates finding an output. Replace with actual parsing of `comfyui_outputs`.
- # For this example, assume ComfyUI API told us about "some_video.mp4" in subfolder "temp_xyz"
- placeholder_comfy_subfolder = "temp_xyz_output_from_comfy"
- placeholder_comfy_filename = "placeholder_comfy_video.mp4"
-
- # Ensure the placeholder subfolder exists within the configured ComfyUI output path for the dummy file creation to succeed
- # In a real scenario, ComfyUI would create this subfolder.
- if COMFYUI_OUTPUT_PATH_CONFIG: # Check added to ensure it's not None
- full_placeholder_comfy_dir = COMFYUI_OUTPUT_PATH_CONFIG / placeholder_comfy_subfolder
- full_placeholder_comfy_dir.mkdir(parents=True, exist_ok=True) # Create if doesn't exist
- generated_final_file_path = full_placeholder_comfy_dir / placeholder_comfy_filename
-
- # For this placeholder, let's simulate a file being created by ComfyUI in its output folder.
- with open(generated_final_file_path, "w") as f: f.write("dummy comfyui output content from its real output dir")
- dprint(f"[Task ID: {task_id}] Output file identification (placeholder). Assumed output from ComfyUI: {generated_final_file_path}")
- else:
- # This case should ideally be caught by the COMFYUI_OUTPUT_PATH_CONFIG check at the start of this section
- # or the task should fail gracefully if it cannot determine the path.
- generated_final_file_path = None # Cannot form path if COMFYUI_OUTPUT_PATH_CONFIG is None
- dprint(f"[Task ID: {task_id}] COMFYUI_OUTPUT_PATH_CONFIG is None, cannot determine placeholder output path.")
- # --- End Temporary placeholder ---
-
- if generated_final_file_path and generated_final_file_path.exists() and generated_final_file_path.stat().st_size > 0:
- # --- 7. Move/Upload the final output file ---
- if DB_TYPE == "sqlite":
- custom_output_path_str = task_params_dict.get("output_path")
- if custom_output_path_str:
- final_video_path = Path(custom_output_path_str).expanduser().resolve()
- final_video_path.parent.mkdir(parents=True, exist_ok=True)
- else:
- output_sub_dir_val = task_params_dict.get("output_sub_dir", task_id) # Use task_id as fallback subdir
- output_sub_dir_path = Path(output_sub_dir_val)
- if output_sub_dir_path.is_absolute():
- final_output_dir = output_sub_dir_path
- else:
- final_output_dir = main_output_dir_base / output_sub_dir_path
- final_output_dir.mkdir(parents=True, exist_ok=True)
- final_video_path = final_output_dir / f"{task_id}_comfy_output{generated_final_file_path.suffix}"
+ print(f"[HEADLESS_DEBUG] Task {task_id}: CALLING WGP GENERATION")
+ print(f"[HEADLESS_DEBUG] Final video_length parameter: {ui_params.get('video_length')}")
+ print(f"[HEADLESS_DEBUG] Resolution: {ui_params.get('resolution')}")
+ print(f"[HEADLESS_DEBUG] Seed: {ui_params.get('seed')}")
+ print(f"[HEADLESS_DEBUG] Steps: {ui_params.get('num_inference_steps')}")
+ print(f"[HEADLESS_DEBUG] Model: {model_filename_for_task}")
+ dprint(f"[Task ID: {task_id}] Calling wgp_mod.generate_video with effective ui_params (first 1000 chars): {json.dumps(ui_params, default=lambda o: 'Unserializable' if isinstance(o, Image.Image) else o.__dict__ if hasattr(o, '__dict__') else str(o), indent=2)[:1000]}...")
+ wgp_mod.generate_video(
+ task=gen_task_placeholder, send_cmd=send_cmd,
+ prompt=ui_params["prompt"],
+ negative_prompt=ui_params.get("negative_prompt", ""),
+ resolution=ui_params["resolution"],
+ video_length=ui_params.get("video_length", 81),
+ seed=ui_params["seed"],
+ num_inference_steps=ui_params.get("num_inference_steps", 30),
+ guidance_scale=ui_params.get("guidance_scale", 5.0),
+ audio_guidance_scale=ui_params.get("audio_guidance_scale", 5.0),
+ flow_shift=ui_params.get("flow_shift", wgp_mod.get_default_flow(model_filename_for_task, wgp_mod.test_class_i2v(model_filename_for_task))),
+ embedded_guidance_scale=ui_params.get("embedded_guidance_scale", 6.0),
+ repeat_generation=ui_params.get("repeat_generation", 1),
+ multi_images_gen_type=ui_params.get("multi_images_gen_type", 0),
+ tea_cache_setting=tea_cache_value,
+ tea_cache_start_step_perc=ui_params.get("tea_cache_start_step_perc", 0),
+ activated_loras=ui_params.get("activated_loras", []),
+ loras_multipliers=ui_params.get("loras_multipliers", ""),
+ image_prompt_type=ui_params.get("image_prompt_type", "T"),
+ image_start=[wgp_mod.convert_image(img) for img in ui_params.get("image_start", [])],
+ image_end=[wgp_mod.convert_image(img) for img in ui_params.get("image_end", [])],
+ model_mode=ui_params.get("model_mode", 0),
+ video_source=ui_params.get("video_source", None),
+ keep_frames_video_source=ui_params.get("keep_frames_video_source", ""),
- shutil.move(str(generated_final_file_path), str(final_video_path))
- output_location_to_db = str(final_video_path.resolve())
- print(f"[Task ID: {task_id}] ComfyUI Output (SQLite) saved to: {output_location_to_db}")
+ video_prompt_type=ui_params.get("video_prompt_type", ""),
+ image_refs=ui_params.get("image_refs", None),
+ video_guide=ui_params.get("video_guide", None),
+ keep_frames_video_guide=ui_params.get("keep_frames_video_guide", ""),
+ video_mask=ui_params.get("video_mask", None),
+ audio_guide=ui_params.get("audio_guide", None),
+ sliding_window_size=ui_params.get("sliding_window_size", 81),
+ sliding_window_overlap=ui_params.get("sliding_window_overlap", 5),
+ sliding_window_overlap_noise=ui_params.get("sliding_window_overlap_noise", 20),
+ sliding_window_discard_last_frames=ui_params.get("sliding_window_discard_last_frames", 0),
+ remove_background_images_ref=ui_params.get("remove_background_images_ref", False),
+ temporal_upsampling=ui_params.get("temporal_upsampling", ""),
+ spatial_upsampling=ui_params.get("spatial_upsampling", ""),
+ RIFLEx_setting=ui_params.get("RIFLEx_setting", 0),
+ slg_switch=ui_params.get("slg_switch", 0),
+ slg_layers=ui_params.get("slg_layers", [9]),
+ slg_start_perc=ui_params.get("slg_start_perc", 10),
+ slg_end_perc=ui_params.get("slg_end_perc", 90),
+ cfg_star_switch=ui_params.get("cfg_star_switch", 0),
+ cfg_zero_step=ui_params.get("cfg_zero_step", -1),
+ prompt_enhancer=ui_params.get("prompt_enhancer", ""),
+ state=state,
+ model_filename=model_filename_for_task
+ )
+ print(f"[Task ID: {task_id}] Generation completed to temporary directory: {wgp_mod.save_path}")
+ print(f"[HEADLESS_DEBUG] Task {task_id}: WGP GENERATION COMPLETED")
+ print(f"[HEADLESS_DEBUG] Temporary output directory: {temp_output_dir}")
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
- encoded_file_name = urllib.parse.quote(generated_final_file_path.name)
- object_name = f"{task_id}/{encoded_file_name}"
- public_url = upload_to_supabase_storage(generated_final_file_path, object_name, SUPABASE_VIDEO_BUCKET)
- if public_url:
- output_location_to_db = public_url
- print(f"[Task ID: {task_id}] ComfyUI Output (Supabase) uploaded. URL: {public_url}")
- else:
- print(f"[WARNING Task ID: {task_id}] Supabase upload failed for ComfyUI output.")
- # generation_success remains False if upload fails
- else:
- print(f"[WARNING Task ID: {task_id}] ComfyUI output generated but DB_TYPE ({DB_TYPE}) not supported for auto-handling or Supabase client not configured.")
+ # List all files in temp directory for debugging
+ temp_dir_contents = list(Path(temp_output_dir).iterdir())
+ print(f"[HEADLESS_DEBUG] Files in temp directory: {len(temp_dir_contents)}")
+ for item in temp_dir_contents:
+ if item.is_file():
+ print(f"[HEADLESS_DEBUG] {item.name} ({item.stat().st_size} bytes)")
- generation_success = True # Set to true if output handled
- else:
- print(f"[ERROR Task ID: {task_id}] ComfyUI execution did not produce a valid output file at {generated_final_file_path} or file is empty.")
- generation_success = False
-
- except requests.exceptions.RequestException as e_req:
- print(f"[ERROR Task ID: {task_id}] ComfyUI API request failed: {e_req}")
- traceback.print_exc()
- generation_success = False
- output_location_to_db = f"ComfyUI API Error: {e_req}"
- except ValueError as e_val:
- print(f"[ERROR Task ID: {task_id}] Value error during ComfyUI task processing: {e_val}")
- traceback.print_exc()
- generation_success = False
- output_location_to_db = f"ComfyUI Processing Value Error: {e_val}"
- except Exception as e:
- print(f"[ERROR Task ID: {task_id}] Unhandled exception during ComfyUI workflow execution: {e}")
- traceback.print_exc()
- generation_success = False
- output_location_to_db = f"ComfyUI Unhandled Exception: {e}"
- finally:
- # Clean up the temporary ComfyUI output directory
- try:
- if Path(temp_comfy_output_dir).exists():
- shutil.rmtree(temp_comfy_output_dir)
- dprint(f"[Task ID: {task_id}] Cleaned up temporary ComfyUI directory: {temp_comfy_output_dir}")
- except Exception as e_clean:
- print(f"[WARNING Task ID: {task_id}] Failed to clean up temp ComfyUI directory {temp_comfy_output_dir}: {e_clean}")
-
- print(f"--- Finished ComfyUI Workflow task ID: {task_id} (Success: {generation_success}) ---")
- return generation_success, output_location_to_db
-
-# --- SM_RESTRUCTURE: New Handler Functions for Travel Tasks ---
-def _handle_travel_orchestrator_task(task_params_from_db: dict, main_output_dir_base: Path, orchestrator_task_id_str: str):
- """Handles the main 'travel_orchestrator' task.
- Its primary role is to parse the orchestration details and enqueue the FIRST 'travel_segment' task.
- """
- dprint(f"_handle_travel_orchestrator_task: Starting for {orchestrator_task_id_str}")
- generation_success = False # Represents success of *orchestration step*, not full travel
- output_location_for_orchestrator_db = None # Orchestrator itself doesn't produce a final video
-
- try:
- # The actual orchestrator payload is nested inside 'params' from the DB row
- if 'orchestrator_details' not in task_params_from_db:
- msg = f"[ERROR Task ID: {orchestrator_task_id_str}] 'orchestrator_details' not found in task_params_from_db."
- print(msg)
- return False, msg
-
- orchestrator_payload = task_params_from_db['orchestrator_details']
- dprint(f"Orchestrator payload for {orchestrator_task_id_str}: {json.dumps(orchestrator_payload, indent=2, default=str)}")
-
- run_id = orchestrator_payload.get("run_id", orchestrator_task_id_str) # Fallback to task_id if run_id missing
-
- # Ensure main_output_dir_for_run from payload is used as the base for current_run_output_dir
- # Fallback to main_output_dir_base (server's default) if not in payload, though it should be.
- base_dir_for_this_run_str = orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve()))
- current_run_output_dir_name = f"travel_run_{run_id}"
- # Create current_run_output_dir under the directory specified by main_output_dir_for_run from the orchestrator payload
- current_run_output_dir = Path(base_dir_for_this_run_str) / current_run_output_dir_name
- current_run_output_dir.mkdir(parents=True, exist_ok=True)
- dprint(f"Orchestrator {orchestrator_task_id_str}: Base output directory for this run: {current_run_output_dir.resolve()}")
-
- # --- Enqueue the first segment task ---
- segment_index = 0 # Start with the first segment
- num_new_segments = orchestrator_payload.get("num_new_segments_to_generate", 0)
-
- if num_new_segments <= 0:
- msg = f"[WARNING Task ID: {orchestrator_task_id_str}] No new segments to generate based on orchestrator payload. Orchestration complete (vacuously)."
- print(msg)
- return True, msg
-
- first_segment_task_id = sm_generate_unique_task_id(f"travel_seg_{run_id}_{segment_index:02d}_")
-
- # Prepare VACE refs for this specific segment
- vace_refs_for_this_segment = [
- ref for ref in orchestrator_payload.get("vace_image_refs_prepared", [])
- if (ref.get("type") == "initial" and ref.get("segment_idx_for_naming") == segment_index) or \
- (ref.get("type") == "final" and ref.get("segment_idx_for_naming") == segment_index)
- ]
-
- segment_task_params = {
- "segment_task_id": first_segment_task_id,
- "orchestrator_task_id_ref": orchestrator_task_id_str, # Reference to this orchestrator task
- "orchestrator_run_id": run_id,
- "segment_index": segment_index,
- "is_first_segment": True,
- "is_last_segment": (segment_index == num_new_segments - 1),
+ generation_success = True
+ except Exception as e:
+ print(f"[ERROR] Task ID {task_id} failed during generation: {e}")
+ traceback.print_exc()
+ finally:
+ wgp_mod.save_path = original_wgp_save_path
+
+ if generation_success:
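+ # Collect every .mp4 that wgp wrote to the temp directory, sorted by filename for deterministic ordering.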
+ generated_video_files = sorted([
+ item for item in Path(temp_output_dir).iterdir()
+ if item.is_file() and item.suffix.lower() == ".mp4"
+ ])
- "full_orchestrator_payload": orchestrator_payload, # Pass the whole orchestrator payload
-
- # Per-segment details (extracted for convenience, but segment handler can also get from full_orchestrator_payload)
- "input_image_paths_resolved": orchestrator_payload["input_image_paths_resolved"], # Needed for anchor logic
- "continue_from_video_resolved_path": orchestrator_payload.get("continue_from_video_resolved_path"), # Needed for anchor logic
- "base_prompt": orchestrator_payload["base_prompts_expanded"][segment_index],
- "negative_prompt": orchestrator_payload["negative_prompts_expanded"][segment_index],
- "segment_frames_target": orchestrator_payload["segment_frames_expanded"][segment_index],
- "frame_overlap_with_next": orchestrator_payload["frame_overlap_expanded"][segment_index],
- "frame_overlap_from_previous": orchestrator_payload["frame_overlap_expanded"][0] if orchestrator_payload.get("continue_from_video_resolved_path") and segment_index == 0 else 0,
- "vace_image_refs_for_segment": vace_refs_for_this_segment,
-
- "current_run_base_output_dir": str(current_run_output_dir.resolve()), # Base for segment's specific output folder
- "parsed_resolution_wh": orchestrator_payload["parsed_resolution_wh"],
- "model_name": orchestrator_payload["model_name"],
- "seed_to_use": orchestrator_payload.get("seed_base", 12345) + segment_index, # Ensure seed_base exists or default
- "execution_engine": orchestrator_payload["execution_engine"],
- # ComfyUI specific workflow paths if engine is ComfyUI, segment handler will use these
- "comfyui_workflow_path_override": orchestrator_payload.get("original_task_args", {}).get("comfyui_workflow_path_for_travel"),
- "comfyui_server_address_override": orchestrator_payload.get("original_task_args", {}).get("comfyui_server_address_for_travel"),
-
- # Pass through common args that segment task (and its WGP/Comfy sub-task) will need
- "use_causvid_lora": orchestrator_payload.get("use_causvid_lora", False),
- "cfg_star_switch": orchestrator_payload.get("cfg_star_switch", 0),
- "cfg_zero_step": orchestrator_payload.get("cfg_zero_step", -1),
- "params_json_str_override": orchestrator_payload.get("params_json_str_override"),
- "fps_helpers": orchestrator_payload["fps_helpers"],
- "last_frame_duplication": orchestrator_payload["last_frame_duplication"],
- "fade_in_params_json_str": orchestrator_payload["fade_in_params_json_str"],
- "fade_out_params_json_str": orchestrator_payload["fade_out_params_json_str"],
- "subsequent_starting_strength_adjustment": orchestrator_payload.get("subsequent_starting_strength_adjustment", 0.0),
- "desaturate_subsequent_starting_frames": orchestrator_payload.get("desaturate_subsequent_starting_frames", 0.0),
- "adjust_brightness_subsequent_starting_frames": orchestrator_payload.get("adjust_brightness_subsequent_starting_frames", 0.0),
- "after_first_post_generation_saturation": orchestrator_payload.get("after_first_post_generation_saturation"), # For segment post-processing
- "crossfade_sharp_amt": orchestrator_payload.get("crossfade_sharp_amt", 0.3), # For stitcher
-
- "upscale_factor": orchestrator_payload.get("upscale_factor", 0.0), # For stitcher
- "upscale_model_name": orchestrator_payload.get("upscale_model_name"), # For stitcher
-
- "debug_mode_enabled": orchestrator_payload.get("debug_mode_enabled", False),
- "skip_cleanup_enabled": orchestrator_payload.get("skip_cleanup_enabled", False) # For stitcher
- }
-
- # Determine DB path from global config (SQLITE_DB_PATH or SUPABASE_CLIENT implicitly)
- db_path_for_add = SQLITE_DB_PATH if DB_TYPE == "sqlite" else None # Supabase uses client
-
- sm_add_task_to_db(
- task_payload=segment_task_params,
- db_path_str=db_path_for_add, # Pass None if Supabase, it uses global client
- task_type="travel_segment", # This is the task_type for the segment generation
- task_id_override=first_segment_task_id
- )
- msg = f"Orchestrator {orchestrator_task_id_str}: Successfully enqueued first 'travel_segment' task (ID: {first_segment_task_id})."
- print(msg)
- generation_success = True
- output_location_for_orchestrator_db = msg
-
- except Exception as e:
- msg = f"[ERROR Task ID: {orchestrator_task_id_str}] Failed during travel orchestration: {e}"
- print(msg)
- traceback.print_exc()
- generation_success = False
- output_location_for_orchestrator_db = msg
-
- return generation_success, output_location_for_orchestrator_db
-
-def _handle_travel_segment_task(wgp_mod, task_params_from_db: dict, main_output_dir_base: Path, segment_task_id_str: str):
- dprint(f"_handle_travel_segment_task: Starting for {segment_task_id_str}")
- # task_params_from_db contains what was enqueued for this specific segment,
- # including potentially 'full_orchestrator_payload'.
- segment_params = task_params_from_db
- generation_success = False # Success of the WGP/Comfy sub-task for this segment
- final_segment_video_output_path_str = None # Output of the WGP sub-task
-
- try:
- # --- 1. Initialization & Parameter Extraction ---
- orchestrator_task_id_ref = segment_params.get("orchestrator_task_id_ref")
- orchestrator_run_id = segment_params.get("orchestrator_run_id")
- segment_idx = segment_params.get("segment_index")
-
- if orchestrator_task_id_ref is None or orchestrator_run_id is None or segment_idx is None:
- msg = f"Segment task {segment_task_id_str} missing critical orchestrator refs or segment_index."
- print(f"[ERROR Task {segment_task_id_str}]: {msg}")
- return False, msg
-
- full_orchestrator_payload = segment_params.get("full_orchestrator_payload")
- if not full_orchestrator_payload:
- dprint(f"Segment {segment_idx}: full_orchestrator_payload not in direct params. Querying orchestrator task {orchestrator_task_id_ref}")
- orchestrator_task_raw_params_json = None
- if DB_TYPE == "sqlite":
- def _get_orc_params(conn):
- cursor = conn.cursor()
- cursor.execute("SELECT params FROM tasks WHERE task_id = ?", (orchestrator_task_id_ref,))
- row = cursor.fetchone()
- return row[0] if row else None
- orchestrator_task_raw_params_json = execute_sqlite_with_retry(SQLITE_DB_PATH, _get_orc_params)
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
- orc_resp = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("params").eq("task_id", orchestrator_task_id_ref).execute()
- if orc_resp.data: orchestrator_task_raw_params_json = orc_resp.data[0]["params"]
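+        # Debug summary of each generated file: frame count, fps, duration and size,
+        # plus a warning when the frame count differs from what was requested.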
+ print(f"[HEADLESS_DEBUG] Task {task_id}: ANALYZING GENERATED FILES")
+ print(f"[HEADLESS_DEBUG] Found {len(generated_video_files)} .mp4 files")
- if orchestrator_task_raw_params_json:
- try:
- fetched_params = json.loads(orchestrator_task_raw_params_json) if isinstance(orchestrator_task_raw_params_json, str) else orchestrator_task_raw_params_json
- full_orchestrator_payload = fetched_params.get("orchestrator_details")
- if not full_orchestrator_payload:
- raise ValueError("'orchestrator_details' key missing in fetched orchestrator task params.")
- dprint(f"Segment {segment_idx}: Successfully fetched orchestrator_details from DB.")
- except Exception as e_fetch_orc:
- msg = f"Segment {segment_idx}: Failed to fetch/parse orchestrator_details from DB for task {orchestrator_task_id_ref}: {e_fetch_orc}"
- print(f"[ERROR Task {segment_task_id_str}]: {msg}")
- return False, msg
- else:
- msg = f"Segment {segment_idx}: Could not retrieve params for orchestrator task {orchestrator_task_id_ref}. Cannot proceed."
- print(f"[ERROR Task {segment_task_id_str}]: {msg}")
- return False, msg
-
- # Now full_orchestrator_payload is guaranteed to be populated or we've exited.
- current_run_base_output_dir_str = segment_params.get("current_run_base_output_dir")
- if not current_run_base_output_dir_str: # Should be passed by orchestrator/prev segment
- current_run_base_output_dir_str = full_orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve()))
- current_run_base_output_dir_str = str(Path(current_run_base_output_dir_str) / f"travel_run_{orchestrator_run_id}")
-
- current_run_base_output_dir = Path(current_run_base_output_dir_str)
- segment_processing_dir = current_run_base_output_dir / f"segment_{segment_idx:02d}_{segment_task_id_str[:8]}"
- segment_processing_dir.mkdir(parents=True, exist_ok=True)
- dprint(f"Segment {segment_idx} (Task {segment_task_id_str}): Processing in {segment_processing_dir.resolve()}")
-
- # --- 2. Guide Video Preparation ---
- actual_guide_video_path_for_wgp: Path | None = None
- path_to_previous_segment_video_output_for_guide: str | None = None
-
- is_first_segment = segment_params.get("is_first_segment", segment_idx == 0) # is_first_segment should be reliable
- is_first_segment_from_scratch = is_first_segment and not full_orchestrator_payload.get("continue_from_video_resolved_path")
- is_first_new_segment_after_continue = is_first_segment and full_orchestrator_payload.get("continue_from_video_resolved_path")
- is_subsequent_segment = not is_first_segment
-
- parsed_res_wh = full_orchestrator_payload["parsed_resolution_wh"]
- fps_helpers = full_orchestrator_payload["fps_helpers"]
- fade_in_duration_str = full_orchestrator_payload["fade_in_params_json_str"]
- fade_out_duration_str = full_orchestrator_payload["fade_out_params_json_str"]
- last_frame_duplication_count = full_orchestrator_payload["last_frame_duplication"]
-
- # Define gray_frame_bgr here for use in subsequent segment strength adjustment
- gray_frame_bgr = sm_create_color_frame(parsed_res_wh, (128, 128, 128))
-
- try: # Parsing fade params
- fade_in_p = json.loads(fade_in_duration_str)
- fi_low, fi_high, fi_curve, fi_factor = float(fade_in_p.get("low_point",0)), float(fade_in_p.get("high_point",1)), str(fade_in_p.get("curve_type","ease_in_out")), float(fade_in_p.get("duration_factor",0))
- except Exception as e_fade_in:
- fi_low, fi_high, fi_curve, fi_factor = 0.0,1.0,"ease_in_out",0.0
- dprint(f"Seg {segment_idx} Warn: Using default fade-in params due to parse error on '{fade_in_duration_str}': {e_fade_in}")
- try:
- fade_out_p = json.loads(fade_out_duration_str)
- fo_low, fo_high, fo_curve, fo_factor = float(fade_out_p.get("low_point",0)), float(fade_out_p.get("high_point",1)), str(fade_out_p.get("curve_type","ease_in_out")), float(fade_out_p.get("duration_factor",0))
- except Exception as e_fade_out:
- fo_low, fo_high, fo_curve, fo_factor = 0.0,1.0,"ease_in_out",0.0
- dprint(f"Seg {segment_idx} Warn: Using default fade-out params due to parse error on '{fade_out_duration_str}': {e_fade_out}")
-
- if is_first_new_segment_after_continue:
- path_to_previous_segment_video_output_for_guide = full_orchestrator_payload.get("continue_from_video_resolved_path")
- if not path_to_previous_segment_video_output_for_guide or not Path(path_to_previous_segment_video_output_for_guide).exists():
- msg = f"Seg {segment_idx}: Continue video path {path_to_previous_segment_video_output_for_guide} invalid."
- print(f"[ERROR Task {segment_task_id_str}]: {msg}"); return False, msg
- elif is_subsequent_segment:
- path_to_previous_segment_video_output_for_guide = segment_params.get("path_to_previous_segment_output")
- if not path_to_previous_segment_video_output_for_guide:
- prev_seg_task_id_to_find = get_previous_segment_task_id(orchestrator_run_id, segment_idx - 1)
- if prev_seg_task_id_to_find:
- path_to_previous_segment_video_output_for_guide = get_task_output_location_from_db(prev_seg_task_id_to_find)
- if not path_to_previous_segment_video_output_for_guide or not Path(path_to_previous_segment_video_output_for_guide).exists():
- msg = f"Seg {segment_idx}: Prev segment output for guide invalid/not found. Expected from prev task output. Path: {path_to_previous_segment_video_output_for_guide}"
- print(f"[ERROR Task {segment_task_id_str}]: {msg}"); return False, msg
-
- try: # Guide Video Creation Block
- guide_video_base_name = f"s{segment_idx}_guide_vid"
- # Ensure guide video is saved in the current segment's processing directory
- actual_guide_video_path_for_wgp = sm_get_unique_target_path(segment_processing_dir, guide_video_base_name, ".mp4")
+ # Analyze each video file found
+        from source.common_utils import get_video_frame_count_and_fps
+        for i, video_file in enumerate(generated_video_files):
+            try:
+                frame_count, fps = get_video_frame_count_and_fps(str(video_file))
+ file_size = video_file.stat().st_size
+ duration = frame_count / fps if fps and fps > 0 else 0
+ print(f"[HEADLESS_DEBUG] Video {i}: {video_file.name}")
+ print(f"[HEADLESS_DEBUG] Frames: {frame_count}")
+ print(f"[HEADLESS_DEBUG] FPS: {fps}")
+ print(f"[HEADLESS_DEBUG] Duration: {duration:.2f}s")
+ print(f"[HEADLESS_DEBUG] Size: {file_size / (1024*1024):.2f} MB")
+ print(f"[HEADLESS_DEBUG] Expected frames: {frame_num_for_wgp}")
+ if frame_count != frame_num_for_wgp:
+ print(f"[HEADLESS_DEBUG] ⚠️ FRAME COUNT MISMATCH! Expected {frame_num_for_wgp}, got {frame_count}")
+ except Exception as e_analysis:
+ print(f"[HEADLESS_DEBUG] ERROR analyzing {video_file.name}: {e_analysis}")
- # segment_frames_target and frame_overlap_from_previous should be in segment_params, passed by orchestrator/prev segment
- base_duration_new_content_for_guide = segment_params.get("segment_frames_target", full_orchestrator_payload["segment_frames_expanded"][segment_idx])
- overlap_connecting_to_previous_for_guide = segment_params.get("frame_overlap_from_previous", 0) # Default to 0 if not set
- guide_video_total_frames = base_duration_new_content_for_guide + overlap_connecting_to_previous_for_guide
-
- if guide_video_total_frames <= 0:
- dprint(f"Task {segment_task_id_str}: Guide video frames {guide_video_total_frames}. No guide will be created."); actual_guide_video_path_for_wgp = None
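+        # Decide how to proceed: fail if nothing was produced, use a single file
+        # directly, or stitch multiple segment files into one video below.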
+ generated_video_file = None
+ if not generated_video_files:
+ print(f"[WARNING Task ID: {task_id}] Generation reported success, but no .mp4 files found in {temp_output_dir}")
+ generation_success = False
+ elif len(generated_video_files) == 1:
+ generated_video_file = generated_video_files[0]
+ dprint(f"[Task ID: {task_id}] Found a single generated video: {generated_video_file}")
else:
- frames_for_guide_list = [sm_create_color_frame(parsed_res_wh, (128,128,128)).copy() for _ in range(guide_video_total_frames)]
- input_images_resolved = full_orchestrator_payload["input_image_paths_resolved"]
- end_anchor_img_path_str: str
- if full_orchestrator_payload.get("continue_from_video_resolved_path"): # Number of input images matches number of new segments
- if segment_idx < len(input_images_resolved):
- end_anchor_img_path_str = input_images_resolved[segment_idx]
- else:
- raise ValueError(f"Seg {segment_idx}: End anchor index {segment_idx} out of bounds for input_images ({len(input_images_resolved)}) with continue_from_video.")
- else: # Not continuing from video, so number of input images is num_segments + 1
- if (segment_idx + 1) < len(input_images_resolved):
- end_anchor_img_path_str = input_images_resolved[segment_idx + 1]
- else:
- raise ValueError(f"Seg {segment_idx}: End anchor index {segment_idx+1} out of bounds for input_images ({len(input_images_resolved)}) when not continuing from video.")
+ dprint(f"[Task ID: {task_id}] Found {len(generated_video_files)} video segments to stitch: {generated_video_files}")
+ stitched_video_path = Path(temp_output_dir) / f"{task_id}_stitched.mp4"
- end_anchor_frame_np = sm_image_to_frame(end_anchor_img_path_str, parsed_res_wh)
- if end_anchor_frame_np is None: raise ValueError(f"Failed to load end anchor image: {end_anchor_img_path_str}")
- num_end_anchor_duplicates = last_frame_duplication_count + 1
- start_anchor_frame_np = None
-
- if is_first_segment_from_scratch:
- start_anchor_img_path_str = input_images_resolved[0]
- start_anchor_frame_np = sm_image_to_frame(start_anchor_img_path_str, parsed_res_wh)
- if start_anchor_frame_np is None: raise ValueError(f"Failed to load start anchor: {start_anchor_img_path_str}")
- if frames_for_guide_list: frames_for_guide_list[0] = start_anchor_frame_np.copy()
- # ... (Ported fade logic for start_anchor_frame_np from original file, using fo_low, fo_high, fo_curve, fo_factor) ...
- pot_max_idx_start_fade = guide_video_total_frames - num_end_anchor_duplicates - 1
- avail_frames_start_fade = max(0, pot_max_idx_start_fade)
- num_start_fade_steps = int(avail_frames_start_fade * fo_factor)
- if num_start_fade_steps > 0:
- actual_start_fade_end_idx = min(num_start_fade_steps -1 , pot_max_idx_start_fade -1)
- easing_fn_out = sm_get_easing_function(fo_curve)
- for k_fo in range(num_start_fade_steps):
- idx_in_guide = 1 + k_fo
- if idx_in_guide > actual_start_fade_end_idx +1 or idx_in_guide >= guide_video_total_frames: break
- alpha_lin = 1.0 - ((k_fo + 1) / float(num_start_fade_steps))
- e_alpha = fo_low + (fo_high - fo_low) * easing_fn_out(alpha_lin)
- e_alpha = np.clip(e_alpha, 0.0, 1.0)
- frames_for_guide_list[idx_in_guide] = cv2.addWeighted(frames_for_guide_list[idx_in_guide].astype(np.float32), 1.0 - e_alpha, start_anchor_frame_np.astype(np.float32), e_alpha, 0).astype(np.uint8)
-
- # ... (Ported fade logic for end_anchor_frame_np from original file, using fi_low, fi_high, fi_curve, fi_factor) ...
- min_idx_end_fade = 1
- max_idx_end_fade = guide_video_total_frames - num_end_anchor_duplicates - 1
- avail_frames_end_fade = max(0, max_idx_end_fade - min_idx_end_fade + 1)
- num_end_fade_steps = int(avail_frames_end_fade * fi_factor)
- if num_end_fade_steps > 0:
- actual_end_fade_start_idx = max(min_idx_end_fade, max_idx_end_fade - num_end_fade_steps + 1)
- easing_fn_in = sm_get_easing_function(fi_curve)
- for k_fi in range(num_end_fade_steps):
- idx_in_guide = actual_end_fade_start_idx + k_fi
- if idx_in_guide > max_idx_end_fade or idx_in_guide >= guide_video_total_frames: break
- alpha_lin = (k_fi + 1) / float(num_end_fade_steps)
- e_alpha = fi_low + (fi_high - fi_low) * easing_fn_in(alpha_lin)
- e_alpha = np.clip(e_alpha, 0.0, 1.0)
- base_f = frames_for_guide_list[idx_in_guide]
- frames_for_guide_list[idx_in_guide] = cv2.addWeighted(base_f.astype(np.float32), 1.0 - e_alpha, end_anchor_frame_np.astype(np.float32), e_alpha, 0).astype(np.uint8)
- elif fi_factor > 0 and avail_frames_end_fade > 0:
- for k_fill in range(min_idx_end_fade, max_idx_end_fade + 1):
- if k_fill < guide_video_total_frames: frames_for_guide_list[k_fill] = end_anchor_frame_np.copy()
-
- elif path_to_previous_segment_video_output_for_guide: # Continued or Subsequent
- prev_vid_total_frames, _ = sm_get_video_frame_count_and_fps(path_to_previous_segment_video_output_for_guide)
- if prev_vid_total_frames is None: raise ValueError("Could not get frame count of previous video for guide.")
- actual_overlap_to_use = min(overlap_connecting_to_previous_for_guide, prev_vid_total_frames)
- start_extraction_idx = max(0, prev_vid_total_frames - actual_overlap_to_use)
- overlap_frames_raw = sm_extract_frames_from_video(path_to_previous_segment_video_output_for_guide, start_extraction_idx, actual_overlap_to_use)
- frames_read_for_overlap = 0
- for k, frame_fp in enumerate(overlap_frames_raw): # frame_fp is frame_from_prev
- if k >= guide_video_total_frames: break
- if frame_fp.shape[1]!=parsed_res_wh[0] or frame_fp.shape[0]!=parsed_res_wh[1]: frame_fp = cv2.resize(frame_fp, parsed_res_wh, interpolation=cv2.INTER_AREA)
- frames_for_guide_list[k] = frame_fp.copy()
- frames_read_for_overlap +=1
-
- # ... (Ported strength/desat/brightness logic from original, using segment_params/full_orchestrator_payload for settings) ...
- strength_adj = full_orchestrator_payload.get("subsequent_starting_strength_adjustment", 0.0)
- desat_factor = full_orchestrator_payload.get("desaturate_subsequent_starting_frames", 0.0)
- bright_adj = full_orchestrator_payload.get("adjust_brightness_subsequent_starting_frames", 0.0)
- if frames_read_for_overlap > 0:
- if fo_factor > 0.0:
- num_init_fade_steps = min(int(frames_read_for_overlap * fo_factor), frames_read_for_overlap)
- easing_fn_fo_ol = sm_get_easing_function(fo_curve)
- for k_fo_ol in range(num_init_fade_steps):
- alpha_l = 1.0 - ((k_fo_ol + 1) / float(num_init_fade_steps))
- eff_s = fo_low + (fo_high - fo_low) * easing_fn_fo_ol(alpha_l)
- eff_s += strength_adj
- eff_s = np.clip(eff_s,0,1)
- base_f=frames_for_guide_list[k_fo_ol]
- frames_for_guide_list[k_fo_ol] = cv2.addWeighted(gray_frame_bgr.astype(np.float32),1-eff_s,base_f.astype(np.float32),eff_s,0).astype(np.uint8)
- if desat_factor > 0:
- g=cv2.cvtColor(frames_for_guide_list[k_fo_ol],cv2.COLOR_BGR2GRAY)
- gb=cv2.cvtColor(g,cv2.COLOR_GRAY2BGR)
- frames_for_guide_list[k_fo_ol]=cv2.addWeighted(frames_for_guide_list[k_fo_ol],1-desat_factor,gb,desat_factor,0)
- if bright_adj!=0:
- frames_for_guide_list[k_fo_ol]=sm_adjust_frame_brightness(frames_for_guide_list[k_fo_ol],bright_adj)
- else:
- eff_s=fo_high+strength_adj; eff_s=np.clip(eff_s,0,1)
- if abs(eff_s-1.0)>1e-5 or desat_factor>0 or bright_adj!=0:
- for k_s_ol in range(frames_read_for_overlap):
- base_f=frames_for_guide_list[k_s_ol];frames_for_guide_list[k_s_ol]=cv2.addWeighted(gray_frame_bgr.astype(np.float32),1-eff_s,base_f.astype(np.float32),eff_s,0).astype(np.uint8)
- if desat_factor>0: g=cv2.cvtColor(frames_for_guide_list[k_s_ol],cv2.COLOR_BGR2GRAY);gb=cv2.cvtColor(g,cv2.COLOR_GRAY2BGR);frames_for_guide_list[k_s_ol]=cv2.addWeighted(frames_for_guide_list[k_s_ol],1-desat_factor,gb,desat_factor,0)
- if bright_adj!=0: frames_for_guide_list[k_s_ol]=sm_adjust_frame_brightness(frames_for_guide_list[k_s_ol],bright_adj)
- # ... (Ported fade-in logic for end_anchor_frame_np for subsequent segments) ...
- min_idx_efs = frames_read_for_overlap; max_idx_efs = guide_video_total_frames - num_end_anchor_duplicates - 1
- avail_f_efs = max(0, max_idx_efs - min_idx_efs + 1); num_efs_steps = int(avail_f_efs * fi_factor)
- if num_efs_steps > 0:
- actual_efs_start_idx = max(min_idx_efs, max_idx_efs - num_efs_steps + 1)
- easing_fn_in_s = sm_get_easing_function(fi_curve)
- for k_fi_s in range(num_efs_steps):
- idx = actual_efs_start_idx+k_fi_s
- if idx > max_idx_efs or idx >= guide_video_total_frames: break
- if idx < min_idx_efs: continue
- alpha_l=(k_fi_s+1)/float(num_efs_steps);e_alpha=fi_low+(fi_high-fi_low)*easing_fn_in_s(alpha_l);e_alpha=np.clip(e_alpha,0,1)
- base_f=frames_for_guide_list[idx];frames_for_guide_list[idx]=cv2.addWeighted(base_f.astype(np.float32),1-e_alpha,end_anchor_frame_np.astype(np.float32),e_alpha,0).astype(np.uint8)
- elif fi_factor > 0 and avail_f_efs > 0:
- for k_fill in range(min_idx_efs, max_idx_efs + 1):
- if k_fill < guide_video_total_frames: frames_for_guide_list[k_fill] = end_anchor_frame_np.copy()
+            # Import the stitcher unconditionally: guarding on globals()/locals() can skip
+            # the import while this function still resolves the name as a local, raising
+            # UnboundLocalError at the call site. Re-importing an already-loaded module is cheap.
+            try:
+                from source.common_utils import stitch_videos_ffmpeg as sm_stitch_videos_ffmpeg
+            except ImportError:
+                print(f"[CRITICAL ERROR Task ID: {task_id}] Failed to import 'stitch_videos_ffmpeg'. Cannot proceed with stitching.")
+                generation_success = False
- # Duplication & final first frame for all types
- if guide_video_total_frames > 0:
- for k_dup in range(min(num_end_anchor_duplicates, guide_video_total_frames)):
- idx_s = guide_video_total_frames - 1 - k_dup
- if idx_s >= 0:
- frames_for_guide_list[idx_s] = end_anchor_frame_np.copy()
- else:
- break # break the loop if idx_s is out of bounds
- if is_first_segment_from_scratch and guide_video_total_frames > 0 and start_anchor_frame_np is not None:
- frames_for_guide_list[0] = start_anchor_frame_np.copy()
-
- if frames_for_guide_list:
- guide_video_file_path = sm_create_video_from_frames_list(frames_for_guide_list, actual_guide_video_path_for_wgp, fps_helpers, parsed_res_wh)
- if guide_video_file_path and guide_video_file_path.exists(): actual_guide_video_path_for_wgp = guide_video_file_path
- else: print(f"ERROR: Failed to create guide video."); actual_guide_video_path_for_wgp = None
- else: actual_guide_video_path_for_wgp = None; dprint("No frames for guide video.")
- except Exception as e_guide:
- print(f"ERROR Task {segment_task_id_str} guide prep: {e_guide}"); traceback.print_exc(); actual_guide_video_path_for_wgp = None
-
- # --- Enqueue WGP Generation as a Sub-Task ---
- if actual_guide_video_path_for_wgp is None and not (is_first_segment_from_scratch or path_to_previous_segment_video_output_for_guide):
- # If guide creation failed AND it was essential (not first segment that could run guideless, or no prev video for subsequent)
- print(f"[ERROR Task {segment_task_id_str}]: Essential guide video failed or not possible. Cannot proceed with WGP sub-task.")
- return False, "Guide video creation failed for segment that requires it."
-
- base_duration_wgp = segment_params["segment_frames_target"]
- overlap_from_previous_wgp = segment_params["frame_overlap_from_previous"]
- current_segment_total_frames_unquantized_wgp = base_duration_wgp + overlap_from_previous_wgp
- final_frames_for_wgp_generation = current_segment_total_frames_unquantized_wgp
- current_wgp_engine = full_orchestrator_payload["execution_engine"]
-
- if current_wgp_engine == "wgp": # Quantization specific to wgp engine
- is_ltxv_m = "ltxv" in full_orchestrator_payload["model_name"].lower()
- latent_s = 8 if is_ltxv_m else 4
- quantized_wgp_f = (current_segment_total_frames_unquantized_wgp // latent_s) * latent_s + 1
- if quantized_wgp_f != current_segment_total_frames_unquantized_wgp: dprint(f"Quantizing WGP: {current_segment_total_frames_unquantized_wgp} to {quantized_wgp_f}")
- final_frames_for_wgp_generation = quantized_wgp_f
-
- wgp_sub_task_id = sm_generate_unique_task_id(f"wgp_sub_{segment_task_id_str[:10]}_")
- wgp_payload = {
- "task_id": wgp_sub_task_id,
- "model": full_orchestrator_payload["model_name"],
- "prompt": segment_params["base_prompt"],
- "negative_prompt": segment_params["negative_prompt"],
- "resolution": f"{parsed_res_wh[0]}x{parsed_res_wh[1]}",
- "frames": final_frames_for_wgp_generation,
- "seed": segment_params["seed_to_use"],
- "output_sub_dir": str(segment_processing_dir.relative_to(main_output_dir_base)),
- "video_guide_path": str(actual_guide_video_path_for_wgp.resolve()) if actual_guide_video_path_for_wgp else None,
- "use_causvid_lora": full_orchestrator_payload.get("use_causvid_lora", False),
- "cfg_star_switch": full_orchestrator_payload.get("cfg_star_switch", 0),
- "cfg_zero_step": full_orchestrator_payload.get("cfg_zero_step", -1),
- "image_refs_paths": [ref["processed_path"] for ref in segment_params.get("vace_image_refs_for_segment", []) if ref.get("processed_path")],
- # ComfyUI specific if engine is comfyui
- "comfyui_workflow_path": full_orchestrator_payload.get("comfyui_workflow_path_override", "placeholder_comfy_workflow_for_travel.json"),
- "comfyui_server_address": full_orchestrator_payload.get("comfyui_server_address_override", "http://127.0.0.1:8188"),
- }
- if full_orchestrator_payload.get("params_json_str_override"):
- try:
- additional_p = json.loads(full_orchestrator_payload["params_json_str_override"])
- additional_p.pop("frames", None); additional_p.pop("video_length", None)
- wgp_payload.update(additional_p)
- except Exception as e_json: dprint(f"Error merging override params: {e_json}")
-
- # If ComfyUI engine, structure inputs for it, especially image paths
- if current_wgp_engine == "comfyui":
- comfy_inputs = {
- "guide_video_path": wgp_payload["video_guide_path"],
- "prompt": wgp_payload["prompt"],
- "negative_prompt": wgp_payload["negative_prompt"],
- "seed": wgp_payload["seed"],
- "frames": wgp_payload["frames"],
- # Start/End images for ComfyUI. Need to determine correct ones based on segment_idx.
- }
- # Determine start_image_path for ComfyUI
- if is_first_segment_from_scratch:
- comfy_inputs["start_image_path"] = full_orchestrator_payload["input_image_paths_resolved"][0]
- elif path_to_previous_segment_video_output_for_guide: # Continued or subsequent
- # Extract first frame of previous output as start_image for ComfyUI
- # This needs a helper: extract_specific_frame_ffmpeg from common_utils
- # from .common_utils import extract_specific_frame_ffmpeg as sm_extract_specific_frame_ffmpeg
- # temp_start_frame_img = sm_get_unique_target_path(segment_processing_dir, f"s{segment_idx}_comfy_start_frame", ".png")
- # prev_vid_fps = sm_get_video_frame_count_and_fps(path_to_previous_segment_video_output_for_guide)[1] or fps_helpers
- # if sm_extract_specific_frame_ffmpeg(path_to_previous_segment_video_output_for_guide, 0, temp_start_frame_img, prev_vid_fps):
- # comfy_inputs["start_image_path"] = str(temp_start_frame_img.resolve())
- # else: dprint("Failed to extract start frame for ComfyUI from previous video.")
- # For now, ComfyUI workflow will need to handle getting start frame from prev video if path given, or use first guide frame.
- comfy_inputs["previous_video_path_for_start_frame"] = path_to_previous_segment_video_output_for_guide
-
- comfy_inputs["end_image_path"] = end_anchor_img_path_str # Determined earlier for guide
- wgp_payload["comfyui_inputs"] = comfy_inputs
-
- dprint(f"Seg {segment_idx}: Enqueuing {current_wgp_engine} sub-task {wgp_sub_task_id} with payload: {json.dumps(wgp_payload, default=str)}")
- db_path_for_wgp_add = SQLITE_DB_PATH if DB_TYPE == "sqlite" else None
- sm_add_task_to_db(wgp_payload, db_path_for_wgp_add, current_wgp_engine, wgp_sub_task_id)
-
- print(f"Seg {segment_idx}: Waiting for {current_wgp_engine} sub-task {wgp_sub_task_id}...")
- from Wan2GP.sm_functions.common_utils import poll_task_status as sm_poll_status # Re-import to ensure correct scope
- raw_wgp_output_video_loc = sm_poll_status(wgp_sub_task_id, db_path_for_wgp_add,
- full_orchestrator_payload.get("poll_interval",15),
- full_orchestrator_payload.get("poll_timeout", 30*60))
-
- if raw_wgp_output_video_loc and Path(raw_wgp_output_video_loc).exists():
- dprint(f"Seg {segment_idx}: WGP sub-task {wgp_sub_task_id} OK. Output: {raw_wgp_output_video_loc}")
- raw_out_path = Path(raw_wgp_output_video_loc)
- final_segment_video_output_path_str = str(raw_out_path.resolve())
- # Ensure output is in segment_processing_dir (it should be due to output_sub_dir in wgp_payload)
- if raw_out_path.parent.resolve() != segment_processing_dir.resolve():
- dprint(f"WGP output {raw_out_path} not in seg dir {segment_processing_dir}. This is unexpected.")
- # Copy if somehow it ended up elsewhere, though output_sub_dir should handle this.
- # For safety, let's ensure it's where we expect it for saturation and next step.
- final_dest_in_seg_dir = segment_processing_dir / raw_out_path.name
- try: shutil.copy2(str(raw_out_path), str(final_dest_in_seg_dir)); final_segment_video_output_path_str = str(final_dest_in_seg_dir.resolve())
- except Exception as e_final_copy: print(f"WARN: Failed to copy {raw_out_path} to {final_dest_in_seg_dir}: {e_final_copy}")
- generation_success = True
- else:
- print(f"ERROR Task {segment_task_id_str}: WGP sub-task {wgp_sub_task_id} failed/timed out/output missing.")
- generation_success = False
-
- # --- Post-generation Saturation ---
- if generation_success and final_segment_video_output_path_str and (is_subsequent_segment or is_first_new_segment_after_continue):
- sat_level = full_orchestrator_payload.get("after_first_post_generation_saturation")
- if sat_level is not None:
- dprint(f"Seg {segment_idx}: Applying post-gen saturation {sat_level}")
- sat_out_base = f"s{segment_idx}_final_sat_{sat_level:.2f}"
- sat_out_path = sm_get_unique_target_path(segment_processing_dir, sat_out_base, Path(final_segment_video_output_path_str).suffix)
- if sm_apply_saturation_to_video_ffmpeg(final_segment_video_output_path_str, sat_out_path, sat_level):
- final_segment_video_output_path_str = str(sat_out_path.resolve())
- else: dprint(f"Seg {segment_idx}: Failed post-gen saturation.")
-
- # --- Enqueue Next Segment OR Stitch Task ---
- if generation_success and final_segment_video_output_path_str:
- db_path_next_add = SQLITE_DB_PATH if DB_TYPE == "sqlite" else None
- if not segment_params["is_last_segment"]:
- next_seg_idx = segment_idx + 1
- next_seg_task_id = sm_generate_unique_task_id(f"travel_seg_{orchestrator_run_id}_{next_seg_idx:02d}_")
- next_seg_payload = { # Copying relevant fields, ensuring full_orchestrator_payload is passed
- "segment_task_id": next_seg_task_id, "orchestrator_task_id_ref": orchestrator_task_id_ref,
- "orchestrator_run_id": orchestrator_run_id, "segment_index": next_seg_idx,
- "is_first_segment": False, "is_last_segment": (next_seg_idx == full_orchestrator_payload["num_new_segments_to_generate"] - 1),
- "path_to_previous_segment_output": final_segment_video_output_path_str,
- "full_orchestrator_payload": full_orchestrator_payload, # CRITICAL: Pass full orchestrator payload
- # Fields specific to this segment, derived from full_orchestrator_payload by next segment handler
- "segment_frames_target": full_orchestrator_payload["segment_frames_expanded"][next_seg_idx],
- "frame_overlap_from_previous": full_orchestrator_payload["frame_overlap_expanded"][segment_idx],
- "frame_overlap_with_next": full_orchestrator_payload["frame_overlap_expanded"][next_seg_idx],
- "base_prompt": full_orchestrator_payload["base_prompts_expanded"][next_seg_idx],
- "negative_prompt": full_orchestrator_payload["negative_prompts_expanded"][next_seg_idx],
- "seed_to_use": full_orchestrator_payload["seed_base"] + next_seg_idx,
- "current_run_base_output_dir": str(current_run_base_output_dir.resolve()),
- "parsed_resolution_wh": full_orchestrator_payload["parsed_resolution_wh"],
- "model_name": full_orchestrator_payload["model_name"],
- "execution_engine": full_orchestrator_payload["execution_engine"],
- # ... (other necessary params from full_orchestrator_payload for next segment)...
- "vace_image_refs_for_segment": [ ref for ref in full_orchestrator_payload.get("vace_image_refs_prepared", []) if (ref.get("type") == "initial" and ref.get("segment_idx_for_naming", -1) == next_seg_idx) or (ref.get("type") == "final" and ref.get("segment_idx_for_naming", -1) == next_seg_idx)],
- "use_causvid_lora": full_orchestrator_payload["use_causvid_lora"],
- "cfg_star_switch": full_orchestrator_payload["cfg_star_switch"],
- "cfg_zero_step": full_orchestrator_payload["cfg_zero_step"],
- "params_json_str_override": full_orchestrator_payload.get("params_json_str_override"),
- "fps_helpers": full_orchestrator_payload["fps_helpers"],
- "last_frame_duplication": full_orchestrator_payload["last_frame_duplication"],
- "fade_in_params_json_str": full_orchestrator_payload["fade_in_params_json_str"],
- "fade_out_params_json_str": full_orchestrator_payload["fade_out_params_json_str"],
- "subsequent_starting_strength_adjustment": full_orchestrator_payload.get("subsequent_starting_strength_adjustment", 0.0),
- "desaturate_subsequent_starting_frames": full_orchestrator_payload.get("desaturate_subsequent_starting_frames", 0.0),
- "adjust_brightness_subsequent_starting_frames": full_orchestrator_payload.get("adjust_brightness_subsequent_starting_frames", 0.0),
- "after_first_post_generation_saturation": full_orchestrator_payload.get("after_first_post_generation_saturation"),
- "debug_mode_enabled": full_orchestrator_payload.get("debug_mode_enabled", False)
- }
- sm_add_task_to_db(next_seg_payload, db_path_next_add, "travel_segment", next_seg_task_id)
- dprint(f"Seg {segment_idx}: Enqueued next seg ({next_seg_idx}) task {next_seg_task_id}.")
- elif segment_params["is_last_segment"]:
- stitch_task_id = sm_generate_unique_task_id(f"travel_stitch_{orchestrator_run_id}_")
- stitch_payload = { # Payloads for stitch task
- "stitch_task_id": stitch_task_id, "orchestrator_task_id_ref": orchestrator_task_id_ref,
- "orchestrator_run_id": orchestrator_run_id,
- "num_total_segments_generated": full_orchestrator_payload["num_new_segments_to_generate"],
- "current_run_base_output_dir": str(current_run_base_output_dir.resolve()),
- "frame_overlap_settings_expanded": full_orchestrator_payload["frame_overlap_expanded"],
- "crossfade_sharp_amt": full_orchestrator_payload.get("crossfade_sharp_amt", 0.3),
- "parsed_resolution_wh": full_orchestrator_payload["parsed_resolution_wh"],
- "fps_final_video": full_orchestrator_payload["fps_helpers"],
- "upscale_factor": full_orchestrator_payload.get("upscale_factor", 0.0),
- "upscale_model_name": full_orchestrator_payload.get("upscale_model_name"),
- "seed_for_upscale": full_orchestrator_payload["seed_base"] + 5000,
- "debug_mode_enabled": full_orchestrator_payload.get("debug_mode_enabled", False),
- "skip_cleanup_enabled": full_orchestrator_payload.get("skip_cleanup_enabled", False),
- "initial_continued_video_path": full_orchestrator_payload.get("continue_from_video_resolved_path")
- }
- sm_add_task_to_db(stitch_payload, db_path_next_add, "travel_stitch", stitch_task_id)
- dprint(f"Seg {segment_idx}: Last. Enqueued stitch task {stitch_task_id}.")
- else:
- dprint(f"Seg {segment_idx}: WGP sub-task failed/skipped. No next task enqueued.")
-
- # The return for _handle_travel_segment_task is about THIS segment's WGP sub-task outcome
- return generation_success, final_segment_video_output_path_str
-
- except Exception as e:
- print(f"ERROR Task {segment_task_id_str}: Unexpected error during segment processing: {e}")
- traceback.print_exc()
- return False, f"Unexpected error: {str(e)[:200]}"
-
-def _handle_travel_stitch_task(task_params_from_db: dict, main_output_dir_base: Path, stitch_task_id_str: str):
- """Handles the final 'travel_stitch' task."""
- dprint(f"_handle_travel_stitch_task: Starting for {stitch_task_id_str}")
- stitch_params = task_params_from_db # This now contains full_orchestrator_payload
- stitch_success = False
- final_video_location_for_db = None
-
- try:
- # --- 1. Initialization & Parameter Extraction ---
- orchestrator_task_id_ref = stitch_params.get("orchestrator_task_id_ref")
- orchestrator_run_id = stitch_params.get("orchestrator_run_id")
- full_orchestrator_payload = stitch_params.get("full_orchestrator_payload")
-
- if not all([orchestrator_task_id_ref, orchestrator_run_id, full_orchestrator_payload]):
- msg = f"Stitch task {stitch_task_id_str} missing critical orchestrator refs or full_orchestrator_payload."
- print(f"[ERROR Task {stitch_task_id_str}]: {msg}")
- return False, msg
-
- current_run_base_output_dir_str = stitch_params.get("current_run_base_output_dir",
- full_orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve())))
- current_run_base_output_dir = Path(current_run_base_output_dir_str)
- # If current_run_base_output_dir was the generic one, ensure it includes the run_id subfolder.
- if not str(current_run_base_output_dir.name).endswith(orchestrator_run_id):
- current_run_base_output_dir = current_run_base_output_dir / f"travel_run_{orchestrator_run_id}"
-
- stitch_processing_dir = current_run_base_output_dir / f"stitch_final_output_{stitch_task_id_str[:8]}"
- stitch_processing_dir.mkdir(parents=True, exist_ok=True)
- dprint(f"Stitch Task {stitch_task_id_str}: Processing in {stitch_processing_dir.resolve()}")
-
- num_expected_new_segments = full_orchestrator_payload["num_new_segments_to_generate"]
- parsed_res_wh = full_orchestrator_payload["parsed_resolution_wh"]
- final_fps = full_orchestrator_payload["fps_helpers"]
- expanded_frame_overlaps = full_orchestrator_payload["frame_overlap_expanded"]
- crossfade_sharp_amt = full_orchestrator_payload.get("crossfade_sharp_amt", 0.3)
- initial_continued_video_path_str = full_orchestrator_payload.get("continue_from_video_resolved_path")
-
- # --- 2. Collect Paths to All Segment Videos ---
- segment_video_paths_for_stitch = []
- if initial_continued_video_path_str and Path(initial_continued_video_path_str).exists():
- dprint(f"Stitch: Prepending initial continued video: {initial_continued_video_path_str}")
- segment_video_paths_for_stitch.append(str(Path(initial_continued_video_path_str).resolve()))
-
- # Query DB for all completed travel_segment tasks for this run_id
- completed_segment_outputs_from_db = []
- if DB_TYPE == "sqlite":
- def _get_sqlite_segments_for_stitch(conn):
- cursor = conn.cursor()
- # Note: JSON extract for integer might need CAST in some SQL versions if directly comparing.
- # Sorting by segment_index is critical.
- sql_query = f"""SELECT params, output_location FROM tasks
- WHERE json_extract(params, '$.orchestrator_run_id') = ?
- AND task_type = 'travel_segment' AND status = ?
- ORDER BY CAST(json_extract(params, '$.segment_index') AS INTEGER) ASC"""
- cursor.execute(sql_query, (orchestrator_run_id, STATUS_COMPLETE))
- rows = cursor.fetchall()
- parsed_rows = []
- for params_json, out_loc in rows:
- try: params_dict = json.loads(params_json); parsed_rows.append((params_dict.get("segment_index"), out_loc))
- except: dprint(f"Stitch: Error parsing params for a segment: {params_json}")
- return parsed_rows
- completed_segment_outputs_from_db = execute_sqlite_with_retry(SQLITE_DB_PATH, _get_sqlite_segments_for_stitch)
-
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
- try:
- # Prefer RPC if available and correctly sorts by segment_index (integer)
- # Example: func_get_completed_travel_segments(p_run_id TEXT)
- # Ensure RPC sorts by CAST(params->>'segment_index' AS INTEGER)
- rpc_response = SUPABASE_CLIENT.rpc("func_get_completed_travel_segments", {"p_run_id": orchestrator_run_id}).execute()
- if rpc_response.data:
- for item in rpc_response.data: # Assuming RPC returns list of dicts like {"segment_index": int, "output_location": str}
- completed_segment_outputs_from_db.append((item.get("segment_index"), item.get("output_location")))
- elif rpc_response.error:
- dprint(f"Stitch Supabase: Error from RPC func_get_completed_travel_segments: {rpc_response.error}. Falling back to direct query.")
- # Fallback direct query (sorting might be string-based if not careful with params->>'segment_index')
- # Manual sorting after fetch will be needed for direct query fallback.
- direct_query_response = SUPABASE_CLIENT.table(PG_TABLE_NAME)\
- .select("params,output_location")\
- .eq("params->>orchestrator_run_id", orchestrator_run_id)\
- .eq("task_type", "travel_segment")\
- .eq("status", STATUS_COMPLETE)\
- .execute() # Add .order() if Supabase client allows direct ordering on JSON cast
- if direct_query_response.data:
- temp_list = []
- for item_direct in direct_query_response.data:
- try: s_idx = int(item_direct["params"].get("segment_index")); temp_list.append((s_idx, item_direct["output_location"]))
- except: dprint(f"Stitch Supabase fallback: Error parsing segment_index for item {item_direct.get('task_id')}")
- temp_list.sort(key=lambda x: x[0] if x[0] is not None else -1) # Sort manually
- completed_segment_outputs_from_db.extend(temp_list)
- elif direct_query_response.error:
- dprint(f"Stitch Supabase fallback query error: {direct_query_response.error}")
- except Exception as e_supabase_fetch:
- dprint(f"Stitch Supabase: Exception during segment fetch: {e_supabase_fetch}. May result in incomplete stitch.")
-
- # Filter and add valid paths from DB query results
- for seg_idx, loc_str in completed_segment_outputs_from_db:
- if loc_str and Path(loc_str).exists():
- segment_video_paths_for_stitch.append(str(Path(loc_str).resolve()))
- dprint(f"Stitch: Adding segment {seg_idx} video: {loc_str}")
- else:
- dprint(f"[WARNING] Stitch: Segment {seg_idx} output path '{loc_str}' is missing or invalid. It will be excluded.")
-
- total_videos_for_stitch = (1 if initial_continued_video_path_str and Path(initial_continued_video_path_str).exists() else 0) + num_expected_new_segments
- if len(segment_video_paths_for_stitch) < total_videos_for_stitch:
- # This is a warning because some segments might have legitimately failed and been skipped by their handlers.
- # The stitcher should proceed with what it has, unless it has zero or one video when multiple were expected.
- dprint(f"[WARNING] Stitch: Expected {total_videos_for_stitch} videos for stitch, but found {len(segment_video_paths_for_stitch)}. Stitching with available videos.")
-
- if not segment_video_paths_for_stitch:
- raise ValueError("Stitch: No valid segment videos found to stitch.")
- if len(segment_video_paths_for_stitch) == 1 and total_videos_for_stitch > 1:
- dprint(f"Stitch: Only one video segment found ({segment_video_paths_for_stitch[0]}) but {total_videos_for_stitch} were expected. Using this single video as the 'stitched' output.")
- # No actual stitching needed, just move/copy this single video to final dest.
-
- # --- 3. Stitching (Crossfade or Concatenate) ---
- temp_stitched_video_output_path = stitch_processing_dir / f"stitched_raw_{orchestrator_run_id}.mp4"
- current_stitched_video_path: Path | None = None
-
- if len(segment_video_paths_for_stitch) == 1:
- # If only one video, use it directly (copy to processing dir for consistency)
- shutil.copy2(segment_video_paths_for_stitch[0], temp_stitched_video_output_path)
- current_stitched_video_path = temp_stitched_video_output_path
- else: # More than one video, proceed with stitching logic
- # Determine if any actual overlap values require frame-based crossfade
- # expanded_frame_overlaps corresponds to NEW segments. If continuing, the overlap between continued video and 1st new segment is overlaps[0].
- # If not continuing, overlap between new_seg0 and new_seg1 is overlaps[0], etc.
- # The number of overlaps is num_new_segments if not continuing, or num_new_segments if continuing (as the first overlap is defined then).
- num_stitch_points = len(segment_video_paths_for_stitch) - 1
- actual_overlaps_for_stitching = []
- if initial_continued_video_path_str: # Continued video exists
- actual_overlaps_for_stitching = expanded_frame_overlaps[:num_stitch_points]
- else: # No continued video, overlaps are between the new segments themselves
- actual_overlaps_for_stitching = expanded_frame_overlaps[:num_stitch_points]
-
- any_positive_overlap = any(o > 0 for o in actual_overlaps_for_stitching)
-
- if any_positive_overlap:
- dprint(f"Stitch: Using cross-fade due to overlap values: {actual_overlaps_for_stitching}")
- all_segment_frames_lists = [sm_extract_frames_from_video(p) for p in segment_video_paths_for_stitch]
- if not all(f_list is not None and len(f_list)>0 for f_list in all_segment_frames_lists):
- raise ValueError("Stitch: Frame extraction failed for one or more segments during cross-fade prep.")
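+            # Stitch the segment files with ffmpeg, verify the result exists and is
+            # non-empty, then remove the per-segment files to save space.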
+ if generation_success:
+ try:
+ video_paths_str = [str(p.resolve()) for p in generated_video_files]
+ sm_stitch_videos_ffmpeg(video_paths_str, str(stitched_video_path.resolve()))
+
+ if stitched_video_path.exists() and stitched_video_path.stat().st_size > 0:
+ generated_video_file = stitched_video_path
+ dprint(f"[Task ID: {task_id}] Successfully stitched video segments to: {generated_video_file}")
+ for segment_file in generated_video_files:
+ try:
+ segment_file.unlink()
+ except OSError as e_clean:
+ print(f"[WARNING Task ID: {task_id}] Could not delete segment file {segment_file}: {e_clean}")
+ else:
+ print(f"[ERROR Task ID: {task_id}] Stitching failed. Output file '{stitched_video_path}' not created or is empty.")
+ generation_success = False
+ except Exception as e_stitch:
+ print(f"[ERROR Task ID: {task_id}] An exception occurred during video stitching: {e_stitch}")
+ traceback.print_exc()
+ generation_success = False
+
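+    # Finalize the output: move the (possibly stitched) video out of the temp
+    # directory and record the location that will be written to the task's DB row.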
+ if generated_video_file and generation_success:
+ dprint(f"[Task ID: {task_id}] Processing final video: {generated_video_file}")
- final_stitched_frames = []
- # Add first segment up to its overlap point
- overlap_for_first_join = actual_overlaps_for_stitching[0] if actual_overlaps_for_stitching else 0
- if len(all_segment_frames_lists[0]) > overlap_for_first_join:
- final_stitched_frames.extend(all_segment_frames_lists[0][:-overlap_for_first_join if overlap_for_first_join > 0 else len(all_segment_frames_lists[0])])
- else: # First segment is shorter than or equal to its overlap, take all of it (it will be fully part of the crossfade)
- final_stitched_frames.extend(all_segment_frames_lists[0])
-
- for i in range(num_stitch_points): # Iterate through join points
- frames_prev_segment = all_segment_frames_lists[i]
- frames_curr_segment = all_segment_frames_lists[i+1]
- current_overlap_val = actual_overlaps_for_stitching[i]
-
- if current_overlap_val > 0:
- faded_frames = sm_cross_fade_overlap_frames(frames_prev_segment, frames_curr_segment, current_overlap_val, "linear_sharp", crossfade_sharp_amt)
- final_stitched_frames.extend(faded_frames)
- else: # Zero overlap, simple concatenation of remainder of prev and all of curr (excluding its own next overlap)
- # This case (zero overlap in a crossfade path) implies we should take the rest of frames_prev_segment if not already fully added
- # The logic for adding first segment and then iterating joins should handle this correctly.
- # If overlap is zero, the previous segment's tail (if any, beyond its own *next* overlap) needs to be added if not already.
- # However, the current loop structure processes joins. If current_overlap_val is 0, it means no crossfade for *this* join.
- # The remainder of frames_prev_segment (if it wasn't fully consumed by *its* previous crossfade) was added before this loop point.
- # So, we just need to add frames_curr_segment up to *its* next overlap point.
- pass # Handled by adding the tail of current segment below.
+ # Use custom output directory if provided, otherwise use default logic
+ if custom_output_dir:
+ target_dir = Path(custom_output_dir)
+ target_dir.mkdir(parents=True, exist_ok=True)
- # Add the tail of the current segment (frames_curr_segment) up to *its* next overlap point
- if (i + 1) < num_stitch_points: # If this is NOT the join before the very last segment
- overlap_for_next_join_of_curr = actual_overlaps_for_stitching[i+1]
- start_index_for_curr_tail = current_overlap_val # Start after the current join's faded part
- end_index_for_curr_tail = len(frames_curr_segment) - (overlap_for_next_join_of_curr if overlap_for_next_join_of_curr > 0 else 0)
- if end_index_for_curr_tail > start_index_for_curr_tail:
- final_stitched_frames.extend(frames_curr_segment[start_index_for_curr_tail : end_index_for_curr_tail])
- else: # This IS the join before the very last segment, so add all remaining frames of last segment after its fade-in
- start_index_for_last_segment_tail = current_overlap_val
- if len(frames_curr_segment) > start_index_for_last_segment_tail:
- final_stitched_frames.extend(frames_curr_segment[start_index_for_last_segment_tail:])
-
- if not final_stitched_frames: raise ValueError("Stitch: No frames produced after cross-fade logic.")
- current_stitched_video_path = sm_create_video_from_frames_list(final_stitched_frames, temp_stitched_video_output_path, final_fps, parsed_res_wh)
+ final_video_path = sm_get_unique_target_path(
+ target_dir,
+ task_id,
+ generated_video_file.suffix
+ )
+
+ # For custom output dir, create a relative path for DB storage
+ try:
+ output_location_to_db = str(final_video_path.relative_to(Path.cwd()))
+ except ValueError:
+ output_location_to_db = str(final_video_path.resolve())
+
+ try:
+ shutil.move(str(generated_video_file), str(final_video_path))
+ print(f"[Task ID: {task_id}] Output video saved to: {final_video_path.resolve()} (DB location: {output_location_to_db})")
+ except Exception as e_move:
+ print(f"[ERROR Task ID: {task_id}] Failed to move video to custom output directory: {e_move}")
+ generation_success = False
+
+ else:
+ # Use the generalized upload-aware output handling for all other cases
+ final_video_path, initial_db_location = prepare_output_path_with_upload(
+ task_id=task_id,
+ filename=generated_video_file.name,
+ main_output_dir_base=main_output_dir_base,
+ dprint=dprint
+ )
+
+ try:
+ shutil.move(str(generated_video_file), str(final_video_path))
+ dprint(f"[Task ID: {task_id}] Moved generated video to: {final_video_path}")
+
+ # Handle Supabase upload (if configured) and get final location for DB
+ output_location_to_db = upload_and_get_final_output_location(
+ final_video_path,
+ task_id,
+ initial_db_location,
+ dprint=dprint
+ )
+
+ print(f"[Task ID: {task_id}] Output video saved to: {final_video_path.resolve()} (DB location: {output_location_to_db})")
+
+ except Exception as e_move:
+ print(f"[ERROR Task ID: {task_id}] Failed to move video to final destination: {e_move}")
+ generation_success = False
else:
- dprint(f"Stitch: Using simple FFmpeg concatenation as no positive overlaps specified.")
- # Ensure common_utils.stitch_videos_ffmpeg is imported or accessible
- from Wan2GP.sm_functions.common_utils import stitch_videos_ffmpeg as sm_stitch_videos_ffmpeg
- if sm_stitch_videos_ffmpeg(segment_video_paths_for_stitch, str(temp_stitched_video_output_path)):
- current_stitched_video_path = temp_stitched_video_output_path
- else: raise RuntimeError("Stitch: Simple FFmpeg concatenation failed.")
-
- if not current_stitched_video_path or not current_stitched_video_path.exists():
- raise RuntimeError(f"Stitch: Stitching process failed, output video not found at {current_stitched_video_path}")
+ print(f"[WARNING Task ID: {task_id}] Generation reported success, but no .mp4 file found in {temp_output_dir}")
+ generation_success = False
- # --- 4. Optional Upscaling ---
- upscale_factor = full_orchestrator_payload.get("upscale_factor", 0.0)
- upscale_model_name = full_orchestrator_payload.get("upscale_model_name")
- current_final_video_path_before_move = current_stitched_video_path # Start with stitched path
-
- if isinstance(upscale_factor, (float, int)) and upscale_factor > 1.0 and upscale_model_name:
- dprint(f"Stitch: Upscaling (x{upscale_factor}) video {current_stitched_video_path.name} using model {upscale_model_name}")
- upscaled_vid_basename = f"stitched_upscaled_{upscale_factor:.1f}x_{orchestrator_run_id}"
- # Upscaled video also goes into the stitch_processing_dir first
- temp_upscaled_video_path = sm_get_unique_target_path(stitch_processing_dir, upscaled_vid_basename, current_stitched_video_path.suffix)
-
- original_frames_count, _ = sm_get_video_frame_count_and_fps(str(current_stitched_video_path))
- if original_frames_count is None or original_frames_count == 0:
- raise ValueError(f"Stitch: Cannot get frame count or 0 frames for video {current_stitched_video_path} before upscaling.")
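+    # Best-effort cleanup of the temporary working directory; a failure here is only a warning.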
+ try:
+ shutil.rmtree(temp_output_dir)
+ dprint(f"[Task ID: {task_id}] Cleaned up temporary directory: {temp_output_dir}")
+ except Exception as e_clean:
+ print(f"[WARNING Task ID: {task_id}] Failed to clean up temporary directory {temp_output_dir}: {e_clean}")
- target_width_upscaled = int(parsed_res_wh[0] * upscale_factor)
- target_height_upscaled = int(parsed_res_wh[1] * upscale_factor)
-
- upscale_sub_task_id = sm_generate_unique_task_id(f"upscale_stitch_{orchestrator_run_id}_")
- # Upscale sub-task outputs to its own folder under main_output_dir_base/travel_run_X/stitch_Y/upscale_Z for clarity
- upscale_sub_task_output_dir_relative = (stitch_processing_dir / f"upscale_assets_{upscale_sub_task_id[:8]}").relative_to(main_output_dir_base)
-
- upscale_payload = {
- "task_id": upscale_sub_task_id,
- "model": upscale_model_name,
- "video_source_path": str(current_stitched_video_path.resolve()), # Absolute path to the stitched video
- "resolution": f"{target_width_upscaled}x{target_height_upscaled}",
- "frames": original_frames_count, # Upscaler needs total frames
- "prompt": full_orchestrator_payload.get("original_task_args",{}).get("upscale_prompt", "cinematic, masterpiece, high detail, 4k"),
- "seed": full_orchestrator_payload.get("seed_for_upscale", full_orchestrator_payload.get("seed_base", 12345) + 5000),
- "output_sub_dir": str(upscale_sub_task_output_dir_relative) # Relative path for where upscaler saves
- }
- # Add other relevant upscale params from original_task_args if present in full_orchestrator_payload
- # e.g., specific LoRAs, guidance for upscaler model if applicable
-
- db_path_for_upscale_add = SQLITE_DB_PATH if DB_TYPE == "sqlite" else None
- # Upscaler engine can be specified in orchestrator payload, defaults to main execution engine
- upscaler_engine_to_use = stitch_params.get("execution_engine_for_upscale", full_orchestrator_payload["execution_engine"])
-
- sm_add_task_to_db(upscale_payload, db_path_for_upscale_add, task_type=upscaler_engine_to_use, task_id_override=upscale_sub_task_id)
- print(f"Stitch Task {stitch_task_id_str}: Enqueued upscale sub-task {upscale_sub_task_id} ({upscaler_engine_to_use}). Waiting...")
-
- from Wan2GP.sm_functions.common_utils import poll_task_status as sm_poll_status_direct # Ensure direct import
- poll_interval_ups = full_orchestrator_payload.get("poll_interval", 15)
- poll_timeout_ups = full_orchestrator_payload.get("poll_timeout_upscale", full_orchestrator_payload.get("poll_timeout", 30 * 60) * 2) # Longer timeout for upscale
-
- upscaled_video_location_from_db = sm_poll_status_direct(
- task_id_to_poll=upscale_sub_task_id,
- db_path_str=db_path_for_upscale_add,
- poll_interval_seconds=poll_interval_ups,
- timeout_seconds=poll_timeout_ups
+ # --- Chaining Logic ---
+ # This block is now executed for any successful primitive task that doesn't return early.
+ if generation_success:
+ chaining_result_path_override = None
+
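+        # Travel sequences: hand the raw WGP output to the travel chainer, which may
+        # post-process it and return a different final path to record for this task.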
+ if task_params_dict.get("travel_chain_details"):
+ dprint(f"WGP Task {task_id} is part of a travel sequence. Attempting to chain.")
+ chain_success, chain_message, final_path_from_chaining = tbi._handle_travel_chaining_after_wgp(
+ wgp_task_params=task_params_dict,
+ actual_wgp_output_video_path=output_location_to_db,
+ wgp_mod=wgp_mod,
+ image_download_dir=image_download_dir,
+ dprint=dprint
)
-
- if upscaled_video_location_from_db and Path(upscaled_video_location_from_db).exists():
- dprint(f"Stitch: Upscale sub-task {upscale_sub_task_id} completed. Output: {upscaled_video_location_from_db}")
- # Copy the upscaled video from its sub-task output dir to the temp_upscaled_video_path in stitch_processing_dir
- shutil.copy2(upscaled_video_location_from_db, str(temp_upscaled_video_path))
- current_final_video_path_before_move = temp_upscaled_video_path
- else:
- print(f"[WARNING] Stitch Task {stitch_task_id_str}: Upscale sub-task {upscale_sub_task_id} failed or output missing. Using non-upscaled video.")
- # current_final_video_path_before_move remains the stitched, non-upscaled video path
- elif upscale_factor > 1.0 and not upscale_model_name:
- dprint(f"Stitch: Upscale factor {upscale_factor} > 1.0 but no upscale_model_name provided. Skipping upscale.")
-
- # --- 5. Final Output Naming and Moving ---
- final_video_name_base = f"travel_final_{orchestrator_run_id}"
- if upscale_factor > 1.0 and "upscaled" in str(current_final_video_path_before_move.name).lower():
- final_video_name_base = f"travel_final_upscaled_{upscale_factor:.1f}x_{orchestrator_run_id}"
+ if chain_success:
+ chaining_result_path_override = final_path_from_chaining
+ dprint(f"Task {task_id}: Travel chaining successful. Message: {chain_message}")
+ else:
+ print(f"[ERROR Task ID: {task_id}] Travel sequence chaining failed after WGP completion: {chain_message}. The raw WGP output '{output_location_to_db}' will be used for this task's DB record.")
- # Final video is placed in main_output_dir_base (e.g. ./steerable_motion_output/)
- # NOT under current_run_base_output_dir (which is ./steerable_motion_output/travel_run_XYZ/)
- final_output_destination_path = sm_get_unique_target_path(main_output_dir_base, final_video_name_base, current_final_video_path_before_move.suffix)
- shutil.move(str(current_final_video_path_before_move), str(final_output_destination_path))
- final_video_location_for_db = str(final_output_destination_path.resolve())
- print(f"Stitch Task {stitch_task_id_str}: Final Video produced at: {final_video_location_for_db}")
- stitch_success = True
-
- except Exception as e_stitch_main:
- msg = f"Stitch Task {stitch_task_id_str}: Main process failed: {e_stitch_main}"
- print(f"[ERROR Task {stitch_task_id_str}]: {msg}"); traceback.print_exc()
- stitch_success = False
- final_video_location_for_db = msg # Store truncated error for DB
-
- finally:
- # --- 6. Cleanup ---
- # Cleanup logic depends on debug_mode_enabled and skip_cleanup_enabled from full_orchestrator_payload
- debug_mode = full_orchestrator_payload.get("debug_mode_enabled", False)
- skip_cleanup = full_orchestrator_payload.get("skip_cleanup_enabled", False)
- do_cleanup_processing_dir = not debug_mode and not skip_cleanup
- do_cleanup_run_dir = do_cleanup_processing_dir and stitch_success # Only clean run_dir if stitch fully succeeded and processing dir is cleaned
-
- if do_cleanup_processing_dir and stitch_processing_dir.exists():
- dprint(f"Stitch: Cleaning up stitch processing directory: {stitch_processing_dir}")
- try: shutil.rmtree(stitch_processing_dir)
- except Exception as e_c1: dprint(f"Stitch: Error cleaning up {stitch_processing_dir}: {e_c1}")
- else:
- dprint(f"Stitch: Skipping cleanup of stitch processing directory: {stitch_processing_dir} (debug:{debug_mode}, skip_cleanup:{skip_cleanup})")
-
- # Cleanup the entire run directory (containing all segment folders) only if stitch was successful AND not in debug/skip_cleanup mode.
- if do_cleanup_run_dir and current_run_base_output_dir.exists():
- dprint(f"Stitch: Full run successful. Cleaning up run directory: {current_run_base_output_dir}")
- try: shutil.rmtree(current_run_base_output_dir)
- except Exception as e_c2: dprint(f"Stitch: Error cleaning up run directory {current_run_base_output_dir}: {e_c2}")
- elif current_run_base_output_dir.exists(): # current_run_base_output_dir exists but conditions for cleanup not met
- dprint(f"Stitch: Skipping cleanup of run directory: {current_run_base_output_dir} (stitch_success:{stitch_success}, debug:{debug_mode}, skip_cleanup:{skip_cleanup})")
-
- return stitch_success, final_video_location_for_db
+ elif task_params_dict.get("different_perspective_chain_details"):
+ # SM_RESTRUCTURE_FIX: Prevent double-chaining. This is now handled in the 'generate_openpose' block.
+ # The only other task type that can have these details is 'wgp', which is the intended target for this block.
+ if task_type != 'generate_openpose':
+ dprint(f"Task {task_id} is part of a different_perspective sequence. Attempting to chain.")
+
+ chain_success, chain_message, final_path_from_chaining = dp._handle_different_perspective_chaining(
+ completed_task_params=task_params_dict,
+ task_output_path=output_location_to_db,
+ dprint=dprint
+ )
+ if chain_success:
+ chaining_result_path_override = final_path_from_chaining
+ dprint(f"Task {task_id}: Different Perspective chaining successful. Message: {chain_message}")
+ else:
+ print(f"[ERROR Task ID: {task_id}] Different Perspective sequence chaining failed: {chain_message}. This may halt the sequence.")
+
+
+ if chaining_result_path_override:
+ path_to_check_existence: Path | None = None
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH and isinstance(chaining_result_path_override, str) and chaining_result_path_override.startswith("files/"):
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ path_to_check_existence = (sqlite_db_parent / "public" / chaining_result_path_override).resolve()
+ dprint(f"Task {task_id}: Chaining returned SQLite relative path '{chaining_result_path_override}'. Resolved to '{path_to_check_existence}' for existence check.")
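+                    # Example mapping (hypothetical filename): "files/clip.mp4" resolves to
+                    # "<directory containing tasks.db>/public/files/clip.mp4".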
+ elif chaining_result_path_override:
+ path_to_check_existence = Path(chaining_result_path_override).resolve()
+ dprint(f"Task {task_id}: Chaining returned absolute-like path '{chaining_result_path_override}'. Resolved to '{path_to_check_existence}' for existence check.")
+
+ if path_to_check_existence and path_to_check_existence.exists() and path_to_check_existence.is_file():
+ is_output_path_different = str(chaining_result_path_override) != str(output_location_to_db)
+ if is_output_path_different:
+ dprint(f"Task {task_id}: Chaining modified output path for DB. Original: {output_location_to_db}, New: {chaining_result_path_override} (Checked file: {path_to_check_existence})")
+ output_location_to_db = chaining_result_path_override
+ elif chaining_result_path_override is not None:
+ print(f"[WARNING Task ID: {task_id}] Chaining reported success, but final path '{chaining_result_path_override}' (checked as '{path_to_check_existence}') is invalid or not a file. Using original WGP output '{output_location_to_db}' for DB.")
+
+
+ # Ensure orchestrator tasks use their DB row ID as task_id so that
+ # downstream sub-tasks reference the right row when updating status.
+ if task_type in {"travel_orchestrator", "different_perspective_orchestrator"}:
+ # Overwrite/insert the canonical task_id inside params to the DB row's ID
+ task_params_dict["task_id"] = task_id
+
+ print(f"--- Finished task ID: {task_id} (Success: {generation_success}) ---")
+ return generation_success, output_location_to_db
-# --- End SM_RESTRUCTURE ---
# -----------------------------------------------------------------------------
# 7. Main server loop
@@ -2211,101 +739,220 @@ def _get_sqlite_segments_for_stitch(conn):
def main():
load_dotenv() # Load .env file variables into environment
- global DB_TYPE, PG_TABLE_NAME, SQLITE_DB_PATH, SUPABASE_URL, SUPABASE_SERVICE_KEY, SUPABASE_VIDEO_BUCKET, SUPABASE_CLIENT
+ global DB_TYPE, SQLITE_DB_PATH, SUPABASE_CLIENT, SUPABASE_VIDEO_BUCKET
# Determine DB type from environment variables
env_db_type = os.getenv("DB_TYPE", "sqlite").lower()
env_pg_table_name = os.getenv("POSTGRES_TABLE_NAME", "tasks")
env_supabase_url = os.getenv("SUPABASE_URL")
env_supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
- env_supabase_bucket = os.getenv("SUPABASE_VIDEO_BUCKET", "videos")
- env_comfyui_output_path = os.getenv("COMFYUI_OUTPUT_PATH") # Load ComfyUI output path
+ env_supabase_anon_key = os.getenv("SUPABASE_ANON_KEY")
+ env_supabase_bucket = os.getenv("SUPABASE_VIDEO_BUCKET", "image_uploads")
env_sqlite_db_path = os.getenv("SQLITE_DB_PATH_ENV") # Read SQLite DB path from .env
cli_args = parse_args()
+ # ------------------------------------------------------------------
+ # Auto-enable file logging when --debug flag is present
+ # ------------------------------------------------------------------
+ if cli_args.debug and not cli_args.save_logging:
+ from datetime import datetime
+ default_logs_dir = Path("logs")
+ default_logs_dir.mkdir(parents=True, exist_ok=True)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ cli_args.save_logging = str(default_logs_dir / f"debug_{timestamp}.log")
+ # ------------------------------------------------------------------
+
+ # Handle --worker parameter for worker-specific logging
+ if cli_args.worker and not cli_args.save_logging:
+ default_logs_dir = Path("logs")
+ default_logs_dir.mkdir(parents=True, exist_ok=True)
+ cli_args.save_logging = str(default_logs_dir / f"{cli_args.worker}.log")
+ # ------------------------------------------------------------------
+
+ # --- Handle --delete-db flag ---
+ if cli_args.delete_db:
+ db_file_to_delete = cli_args.db_file
+ env_sqlite_db_path = os.getenv("SQLITE_DB_PATH_ENV")
+ if env_sqlite_db_path:
+ db_file_to_delete = env_sqlite_db_path
+
+ db_files_to_remove = [
+ db_file_to_delete,
+ f"{db_file_to_delete}-wal",
+ f"{db_file_to_delete}-shm"
+ ]
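+        # The -wal and -shm sidecars hold SQLite's write-ahead log and shared-memory
+        # index; removing them together with the main file ensures no stale journal
+        # data is carried over into a freshly created database.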
+
+ for db_file in db_files_to_remove:
+ if Path(db_file).exists():
+ try:
+ Path(db_file).unlink()
+ print(f"[DELETE-DB] Removed: {db_file}")
+ except Exception as e:
+ print(f"[DELETE-DB ERROR] Could not remove {db_file}: {e}")
+
+ print("[DELETE-DB] Database cleanup complete. Starting fresh.")
+ # --- End delete-db handling ---
+
+ # --- Setup logging to file if requested ---
+ if cli_args.save_logging:
+ import logging
+
+ log_file_path = Path(cli_args.save_logging)
+ log_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Create a custom stream that writes to both console and file
+ class DualWriter:
+ def __init__(self, log_file_path):
+ self.terminal = sys.stdout
+ self.log_file = open(log_file_path, 'a', encoding='utf-8')
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log_file.write(message)
+ self.log_file.flush() # Ensure immediate write
+
+ def flush(self):
+ self.terminal.flush()
+ self.log_file.flush()
+
+ def close(self):
+ if hasattr(self, 'log_file'):
+ self.log_file.close()
+
+ # Redirect stdout to our dual writer
+ sys.stdout = DualWriter(log_file_path)
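+        # Note: only stdout is redirected; anything written to stderr (e.g. tracebacks
+        # from unhandled exceptions) still appears on the console only.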
+
+ if cli_args.worker:
+ print(f"[WORKER LOGGING] Worker '{cli_args.worker}' output will be saved to: {log_file_path.resolve()}")
+ else:
+ print(f"[LOGGING] All output will be saved to: {log_file_path.resolve()}")
+
+ # Ensure cleanup on exit
+ import atexit
+ atexit.register(lambda: hasattr(sys.stdout, 'close') and sys.stdout.close())
+ # --- End logging setup ---
+
# --- Configure DB Type and Connection Globals ---
- # This block sets DB_TYPE, SQLITE_DB_PATH, SUPABASE_CLIENT, PG_TABLE_NAME etc.
- if env_db_type == "supabase" and env_supabase_url and env_supabase_key:
+ # This block sets DB_TYPE, SQLITE_DB_PATH, SUPABASE_CLIENT, etc. in the db_ops module
+ if cli_args.db_type == "supabase" and cli_args.supabase_url and cli_args.supabase_access_token:
try:
- temp_supabase_client = create_client(env_supabase_url, env_supabase_key)
- if temp_supabase_client:
- DB_TYPE = "supabase"
- PG_TABLE_NAME = env_pg_table_name
- SUPABASE_URL = env_supabase_url
- SUPABASE_SERVICE_KEY = env_supabase_key
- SUPABASE_VIDEO_BUCKET = env_supabase_bucket
- SUPABASE_CLIENT = temp_supabase_client
- # Initial print about using Supabase will be done after migrations
- else:
- raise Exception("Supabase client creation returned None")
+ # --- New PAT/JWT-based Supabase client initialization ---
+ # For PATs and user JWTs, we primarily use edge functions and avoid direct
+ # Supabase client authentication which expects specific JWT formats.
+ # We'll create a client with the service key for internal operations,
+ # but use the access token in headers for edge function calls.
+
+ # Use service key for admin operations if available, otherwise anon key
+ client_key = env_supabase_key or cli_args.supabase_anon_key or env_supabase_anon_key
+
+ if not client_key:
+ raise ValueError("Need either service key or anon key for Supabase client initialization.")
+
+ dprint(f"Supabase: Initializing client for {cli_args.supabase_url}.")
+ temp_supabase_client = create_client(cli_args.supabase_url, client_key)
+
+ # For PATs and user tokens, we'll primarily rely on edge functions
+ # The access token will be passed in Authorization headers
+ dprint(f"Supabase: Client initialized. Access token will be used in edge function calls.")
+
+ # --- Assign to db_ops globals on success ---
+ db_ops.DB_TYPE = "supabase"
+ db_ops.PG_TABLE_NAME = env_pg_table_name
+ db_ops.SUPABASE_URL = cli_args.supabase_url
+ db_ops.SUPABASE_SERVICE_KEY = env_supabase_key # Keep service key if present
+ db_ops.SUPABASE_VIDEO_BUCKET = env_supabase_bucket
+ db_ops.SUPABASE_CLIENT = temp_supabase_client
+ # Store the access token for use in Edge Function calls
+ db_ops.SUPABASE_ACCESS_TOKEN = cli_args.supabase_access_token
+
+ # Local globals for convenience
+ DB_TYPE = "supabase"
+ SUPABASE_CLIENT = temp_supabase_client
+ SUPABASE_VIDEO_BUCKET = env_supabase_bucket
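+            # Illustrative sketch of how the stored access token is expected to reach an
+            # Edge Function (the "claim-next-task" name comes from the polling loop below;
+            # the payload shape shown here is an assumption, not the project's actual API):
+            #
+            #   import requests
+            #   resp = requests.post(
+            #       f"{db_ops.SUPABASE_URL}/functions/v1/claim-next-task",
+            #       headers={"Authorization": f"Bearer {db_ops.SUPABASE_ACCESS_TOKEN}"},
+            #       json={"worker_id": cli_args.worker},
+            #       timeout=30,
+            #   )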
+
except Exception as e:
- print(f"[ERROR] Failed to initialize Supabase client: {e}. Check SUPABASE_URL and SUPABASE_SERVICE_KEY.")
+ print(f"[ERROR] Failed to initialize Supabase client: {e}")
+ traceback.print_exc()
print("Falling back to SQLite due to Supabase client initialization error.")
- DB_TYPE = "sqlite"
- # Determine SQLite path for fallback: .env, then CLI, then default
- SQLITE_DB_PATH = env_sqlite_db_path if env_sqlite_db_path else cli_args.db_file
- elif env_db_type == "sqlite":
- DB_TYPE = "sqlite"
- # Determine SQLite path: .env, then CLI, then default
- SQLITE_DB_PATH = env_sqlite_db_path if env_sqlite_db_path else cli_args.db_file
- else: # Default to sqlite if .env DB_TYPE is unrecognized or not set
- print(f"DB_TYPE '{env_db_type}' in .env is not recognized or not set. Defaulting to SQLite.")
+ db_ops.DB_TYPE = "sqlite"
+ db_ops.SQLITE_DB_PATH = env_sqlite_db_path if env_sqlite_db_path else cli_args.db_file
+ else: # Default to sqlite if .env DB_TYPE is unrecognized or not set, or if it's explicitly "sqlite"
+ if cli_args.db_type != "sqlite":
+ print(f"DB_TYPE '{cli_args.db_type}' in CLI args is not recognized. Defaulting to SQLite.")
+ db_ops.DB_TYPE = "sqlite"
+ db_ops.SQLITE_DB_PATH = env_sqlite_db_path if env_sqlite_db_path else cli_args.db_file
DB_TYPE = "sqlite"
- # Determine SQLite path: .env, then CLI, then default
- SQLITE_DB_PATH = env_sqlite_db_path if env_sqlite_db_path else cli_args.db_file
+ SQLITE_DB_PATH = db_ops.SQLITE_DB_PATH
# --- End DB Type Configuration ---
# --- Run DB Migrations ---
# Must be after DB type/config is determined but before DB schema is strictly enforced by init_db or heavy use.
- _run_db_migrations()
+ # Note: Migrations completed - now using Edge Functions exclusively
+ # db_ops._run_db_migrations() # Commented out - migration to Edge Functions complete
# --- End DB Migrations ---
- global COMFYUI_OUTPUT_PATH_CONFIG
- if env_comfyui_output_path:
- COMFYUI_OUTPUT_PATH_CONFIG = Path(env_comfyui_output_path)
- if not COMFYUI_OUTPUT_PATH_CONFIG.is_dir():
- print(f"[WARNING] COMFYUI_OUTPUT_PATH '{COMFYUI_OUTPUT_PATH_CONFIG}' from .env is not a valid directory. ComfyUI tasks may fail to retrieve outputs.")
- # COMFYUI_OUTPUT_PATH_CONFIG = None # Or let it proceed and fail in the handler
- else:
- print(f"ComfyUI output path configured: {COMFYUI_OUTPUT_PATH_CONFIG}")
- else:
- print("[INFO] COMFYUI_OUTPUT_PATH not set in .env. ComfyUI tasks will require this to retrieve outputs.")
+ # --- Handle --migrate-only flag --- (Section 6)
+ if cli_args.migrate_only:
+ print("Database migrations complete (called with --migrate-only). Exiting.")
+ sys.exit(0)
+ # --- End --migrate-only handler ---
+
main_output_dir = Path(cli_args.main_output_dir)
main_output_dir.mkdir(parents=True, exist_ok=True)
print(f"WanGP Headless Server Started.")
- if DB_TYPE == "supabase":
- print(f"Monitoring Supabase (PostgreSQL backend) table: {PG_TABLE_NAME}")
+ if cli_args.worker:
+ print(f"Worker ID: {cli_args.worker}")
+ if db_ops.DB_TYPE == "supabase":
+ print(f"Monitoring Supabase (PostgreSQL backend) table: {db_ops.PG_TABLE_NAME}")
else: # SQLite
- print(f"Monitoring SQLite database: {SQLITE_DB_PATH}")
+ print(f"Monitoring SQLite database: {db_ops.SQLITE_DB_PATH}")
print(f"Outputs will be saved under: {main_output_dir}")
print(f"Polling interval: {cli_args.poll_interval} seconds.")
# Initialize database
- if DB_TYPE == "supabase":
- init_db_supabase() # New call, uses globals
+ # Supabase table/schema assumed to exist; skip initialization RPC
+ if db_ops.DB_TYPE == "supabase":
+ dprint("Supabase: Skipping init_db_supabase – table assumed present.")
else: # SQLite
- init_db(SQLITE_DB_PATH) # Existing SQLite init
+ db_ops.init_db() # Existing SQLite init
# Activate global debug switch early so that all subsequent code paths can use dprint()
global debug_mode
debug_mode = cli_args.debug
+ db_ops.debug_mode = cli_args.debug # Also set it in the db_ops module
dprint("Verbose debug logging enabled.")
- original_argv = sys.argv.copy()
- sys.argv = ["Wan2GP/wgp.py"]
+ # Before importing Wan2GP.wgp we need to ensure that a CPU-only PyTorch build
+ # does not crash when third-party code unconditionally calls CUDA helpers.
+ try:
+ import torch # Local import to limit overhead when torch is missing
+
+ if not torch.cuda.is_available():
+ # Monkey-patch *early* to stub out problematic functions that trigger
+ # torch.cuda initialisation on CPU-only builds (which raises
+ # "Torch not compiled with CUDA enabled"). This is safe because the
+ # downstream Wan2GP code only checks the return value tuple.
+ def _dummy_get_device_capability(device=None):
+ return (0, 0)
+
+ torch.cuda.get_device_capability = _dummy_get_device_capability # type: ignore[attr-defined]
+
+ # Some libraries check for the attribute rather than calling it, so
+ # we also advertise zero GPUs.
+ torch.cuda.device_count = lambda: 0 # type: ignore[attr-defined]
+ except ImportError:
+ # torch not installed – Wan2GP import will fail later anyway.
+ pass
+
+ # Ensure Wan2GP sees a clean argv list and Gradio functions are stubbed
+ sys.argv = ["Wan2GP/wgp.py"] # Prevent wgp.py from parsing headless.py CLI args
patch_gradio()
-
- # Add Wan2GP directory to Python path and import wgp directly
- wan2gp_path = Path(__file__).parent / "Wan2GP"
- if str(wan2gp_path) not in sys.path:
- sys.path.insert(0, str(wan2gp_path))
-
- import wgp as wgp_mod
-
- sys.argv = original_argv
+
+ # Import wgp from the Wan2GP sub-package
+ from Wan2GP import wgp as wgp_mod
# Apply wgp.py global config overrides
if cli_args.wgp_attention_mode is not None: wgp_mod.attention_mode = cli_args.wgp_attention_mode
@@ -2338,7 +985,7 @@ def main():
print(" Ensure models are manually placed in 'ckpts' or a video generation task runs first to trigger downloads.")
# --- End early model download ---
- db_path_str = str(SQLITE_DB_PATH) if DB_TYPE == "sqlite" else PG_TABLE_NAME # Use consistent string path for db functions
+ db_path_str = str(db_ops.SQLITE_DB_PATH) if db_ops.DB_TYPE == "sqlite" else db_ops.PG_TABLE_NAME # Use consistent string path for db functions
# --- Ensure LoRA directories expected by wgp.py exist, especially for LTXV ---
try:
@@ -2364,48 +1011,147 @@ def main():
# --- End LoRA directory check ---
try:
+ # --- Add a one-time diagnostic log for task counts ---
+ if db_ops.DB_TYPE == "sqlite" and debug_mode:
+ try:
+ counts = db_ops.get_initial_task_counts()
+ if counts:
+ total_tasks, queued_tasks = counts
+ dprint(f"SQLite Initial State: Total tasks in '{db_ops.PG_TABLE_NAME}': {total_tasks}. Tasks with status '{db_ops.STATUS_QUEUED}': {queued_tasks}.")
+ except Exception as e_diag:
+ dprint(f"SQLite Diagnostic Error: Could not get initial task counts: {e_diag}")
+ # --- End one-time diagnostic log ---
+
while True:
task_info = None
current_task_id_for_status_update = None # Used to hold the task_id for status updates
+ current_project_id = None # To hold the project_id for the current task
- if DB_TYPE == "supabase":
- dprint(f"Checking for queued tasks in Supabase (PostgreSQL backend) table {PG_TABLE_NAME} via Supabase RPC...")
- task_info = get_oldest_queued_task_supabase()
+ if db_ops.DB_TYPE == "supabase":
+ dprint(f"Checking for queued tasks in Supabase (PostgreSQL backend) table {db_ops.PG_TABLE_NAME} via Supabase RPC...")
+ task_info = db_ops.get_oldest_queued_task_supabase(worker_id=cli_args.worker)
+ dprint(f"Supabase task_info: {task_info}") # ADDED DPRINT
if task_info:
current_task_id_for_status_update = task_info["task_id"]
- # Status is already set to IN_PROGRESS by func_claim_task RPC
+ # Status is already set to IN_PROGRESS by claim-next-task Edge Function
else: # SQLite
- dprint(f"Checking for queued tasks in SQLite {SQLITE_DB_PATH}...")
- task_info = get_oldest_queued_task(SQLITE_DB_PATH)
+ dprint(f"Checking for queued tasks in SQLite {db_ops.SQLITE_DB_PATH}...")
+ task_info = db_ops.get_oldest_queued_task()
+ dprint(f"SQLite task_info: {task_info}") # ADDED DPRINT
if task_info:
current_task_id_for_status_update = task_info["task_id"]
- update_task_status(SQLITE_DB_PATH, current_task_id_for_status_update, STATUS_IN_PROGRESS)
+ db_ops.update_task_status(current_task_id_for_status_update, db_ops.STATUS_IN_PROGRESS)
if not task_info:
dprint("No queued tasks found. Sleeping...")
- time.sleep(cli_args.poll_interval)
+ if db_ops.DB_TYPE == "sqlite":
+ # Wait until either the WAL/db file changes or the normal
+ # poll interval elapses. This reduces perceived latency
+ # without hammering the database.
+ _wait_for_sqlite_change(db_ops.SQLITE_DB_PATH, cli_args.poll_interval)
+ else:
+ time.sleep(cli_args.poll_interval)
continue
# current_task_data = task_info["params"] # Params are already a dict
current_task_params = task_info["params"]
current_task_type = task_info["task_type"] # Retrieve task_type
+ current_project_id = task_info.get("project_id") # Get project_id, might be None if not returned
+
+ # This fallback logic remains, but it's less likely to be needed
+ # if get_oldest_queued_task and its supabase equivalent are reliable.
+ if current_project_id is None and current_task_id_for_status_update:
+ dprint(f"Project ID not directly available for task {current_task_id_for_status_update}. Attempting to fetch manually...")
+ if db_ops.DB_TYPE == "supabase" and db_ops.SUPABASE_CLIENT:
+ try:
+ # Using 'id' as the column name for task_id based on Supabase schema conventions seen elsewhere (e.g. init_db)
+ response = db_ops.SUPABASE_CLIENT.table(db_ops.PG_TABLE_NAME)\
+ .select("project_id")\
+ .eq("id", current_task_id_for_status_update)\
+ .single()\
+ .execute()
+ if response.data and response.data.get("project_id"):
+ current_project_id = response.data["project_id"]
+ dprint(f"Successfully fetched project_id '{current_project_id}' for task {current_task_id_for_status_update} from Supabase.")
+ else:
+ dprint(f"Could not fetch project_id for task {current_task_id_for_status_update} from Supabase. Response data: {response.data}, error: {response.error}")
+ except Exception as e_fetch_proj_id:
+ dprint(f"Exception while fetching project_id for {current_task_id_for_status_update} from Supabase: {e_fetch_proj_id}")
+ elif db_ops.DB_TYPE == "sqlite": # Should have been fetched, but as a fallback
+ # This fallback no longer needs its own db connection logic.
+ # A new helper could be added to db_ops if this is truly needed,
+ # but for now, we assume the primary fetch works.
+ dprint(f"Project_id was not fetched for {current_task_id_for_status_update} from SQLite. This is unexpected.")
- print(f"Found task: {current_task_id_for_status_update} of type: {current_task_type}")
+
+ # Critical check: project_id is NOT NULL for sub-tasks created by orchestrator
+ if current_project_id is None and current_task_type == "travel_orchestrator":
+ print(f"[CRITICAL ERROR] Task {current_task_id_for_status_update} (travel_orchestrator) has no project_id. Sub-tasks cannot be created. Skipping task.")
+ # Update status to FAILED to prevent re-processing this broken state
+ error_message_for_db = "Failed: Orchestrator task missing project_id, cannot create sub-tasks."
+ if db_ops.DB_TYPE == "supabase":
+ db_ops.update_task_status_supabase(current_task_id_for_status_update, db_ops.STATUS_FAILED, error_message_for_db)
+ else:
+ db_ops.update_task_status(current_task_id_for_status_update, db_ops.STATUS_FAILED, error_message_for_db)
+ time.sleep(1) # Brief pause
+ continue # Skip to next polling cycle
+
+ print(f"Found task: {current_task_id_for_status_update} of type: {current_task_type}, Project ID: {current_project_id}")
# Status already set to IN_PROGRESS if task_info is not None
- task_succeeded, output_location = process_single_task(wgp_mod, current_task_params, main_output_dir, current_task_type)
+            # Pull the per-segment image download directory from the task params, if one was provided.
+ segment_image_download_dir = current_task_params.get("segment_image_download_dir")
+
+ # Ensure orchestrator tasks propagate the DB row ID as their canonical task_id *before* processing
+ if current_task_type in {"travel_orchestrator", "different_perspective_orchestrator"}:
+ current_task_params["task_id"] = current_task_id_for_status_update
+ if "orchestrator_details" in current_task_params:
+ current_task_params["orchestrator_details"]["orchestrator_task_id"] = current_task_id_for_status_update
+
+ task_succeeded, output_location = process_single_task(
+ wgp_mod, current_task_params, main_output_dir, current_task_type, current_project_id,
+ image_download_dir=segment_image_download_dir,
+ apply_reward_lora=cli_args.apply_reward_lora,
+ colour_match_videos=cli_args.colour_match_videos,
+ mask_active_frames=cli_args.mask_active_frames
+ )
if task_succeeded:
- if DB_TYPE == "supabase":
- update_task_status_supabase(current_task_id_for_status_update, STATUS_COMPLETE, output_location)
+ # Orchestrator tasks stay "In Progress" until their children report back.
+ orchestrator_types_waiting = {"travel_orchestrator", "different_perspective_orchestrator"}
+
+ if current_task_type in orchestrator_types_waiting:
+ # Keep status as IN_PROGRESS (already set when we claimed the task).
+ # We still store the output message (if any) so operators can see it.
+ db_ops.update_task_status(
+ current_task_id_for_status_update,
+ db_ops.STATUS_IN_PROGRESS,
+ output_location,
+ )
+ print(
+ f"Task {current_task_id_for_status_update} queued child tasks; awaiting completion before finalising."
+ )
else:
- update_task_status(SQLITE_DB_PATH, current_task_id_for_status_update, STATUS_COMPLETE, output_location)
- print(f"Task {current_task_id_for_status_update} completed successfully. Output location: {output_location}")
+ if db_ops.DB_TYPE == "supabase":
+ db_ops.update_task_status_supabase(
+ current_task_id_for_status_update,
+ db_ops.STATUS_COMPLETE,
+ output_location,
+ )
+ else:
+ db_ops.update_task_status(
+ current_task_id_for_status_update,
+ db_ops.STATUS_COMPLETE,
+ output_location,
+ )
+ print(
+ f"Task {current_task_id_for_status_update} completed successfully. Output location: {output_location}"
+ )
else:
- if DB_TYPE == "supabase":
- update_task_status_supabase(current_task_id_for_status_update, STATUS_FAILED, output_location)
+ if db_ops.DB_TYPE == "supabase":
+ db_ops.update_task_status_supabase(current_task_id_for_status_update, db_ops.STATUS_FAILED, output_location)
else:
- update_task_status(SQLITE_DB_PATH, current_task_id_for_status_update, STATUS_FAILED, output_location)
+ db_ops.update_task_status(current_task_id_for_status_update, db_ops.STATUS_FAILED, output_location)
print(f"Task {current_task_id_for_status_update} failed. Review logs for errors. Output location recorded: {output_location if output_location else 'N/A'}")
time.sleep(1) # Brief pause before checking for the next task
@@ -2421,253 +1167,46 @@ def main():
print(f"Error during offloadobj release: {e_release}")
print("Server stopped.")
-# -----------------------------------------------------------------------------
-# Supabase Storage Helper
-# -----------------------------------------------------------------------------
-def upload_to_supabase_storage(local_file_path: Path, object_name_in_bucket: str, bucket_name: str) -> str | None:
- """Uploads a file to Supabase storage and returns its public URL."""
- if not SUPABASE_CLIENT:
- print("[ERROR] Supabase client not initialized. Cannot upload.")
- return None
-
- try:
-
- with open(local_file_path, 'rb') as f:
- # The object name can include paths, e.g., "videos/task_123.mp4"
- res = SUPABASE_CLIENT.storage.from_(bucket_name).upload(
- path=object_name_in_bucket,
- file=f,
- file_options={"cache-control": "3600", "upsert": "true"} # Upsert to overwrite if exists
- )
-
- dprint(f"Supabase upload response data: {res.json() if hasattr(res, 'json') else res}")
- # Get public URL
- public_url_response = SUPABASE_CLIENT.storage.from_(bucket_name).get_public_url(object_name_in_bucket)
-
- # public_url_response is a string directly
- dprint(f"Supabase get_public_url response: {public_url_response}")
- if public_url_response:
- print(f"INFO: Successfully uploaded {local_file_path.name} to Supabase bucket '{bucket_name}' as '{object_name_in_bucket}'. URL: {public_url_response}")
- return public_url_response
- else:
- print(f"[ERROR] Failed to get public URL for {object_name_in_bucket} in Supabase bucket '{bucket_name}'. Upload may have succeeded.")
- return None # Or construct a presumed URL if your bucket is public and path is known
-
- except Exception as e:
- print(f"[ERROR] Failed to upload {local_file_path.name} to Supabase: {e}")
- traceback.print_exc()
- return None
-
-# -----------------------------------------------------------------------------
-# PostgreSQL Specific DB Functions (Now using Supabase RPC)
-# Renamed to reflect Supabase usage more directly
-# -----------------------------------------------------------------------------
-
-def init_db_supabase(): # Renamed from init_db_postgres
- """Initializes the PostgreSQL tasks table via Supabase RPC if it doesn't exist."""
- if not SUPABASE_CLIENT:
- print("[ERROR] Supabase client not initialized. Cannot initialize database table.")
- sys.exit(1)
- try:
- # RPC call to the SQL function func_initialize_tasks_table
- # IMPORTANT: The func_initialize_tasks_table SQL function itself
- # must be updated to include "task_type TEXT NOT NULL" in its
- # CREATE TABLE statement for the specified p_table_name.
- SUPABASE_CLIENT.rpc("func_initialize_tasks_table", {"p_table_name": PG_TABLE_NAME}).execute()
- print(f"Supabase RPC: Table '{PG_TABLE_NAME}' initialization requested.")
- # Note: RPC for DDL might not return specific confirmation beyond successful execution.
- # You might need to add a SELECT to confirm table existence if strict feedback is needed.
- except Exception as e: # Broader exception for Supabase/PostgREST errors
- print(f"[ERROR] Supabase RPC for table initialization failed: {e}")
- traceback.print_exc()
- sys.exit(1)
-
-def get_oldest_queued_task_supabase(): # Renamed from get_oldest_queued_task_postgres
- """Fetches the oldest task via Supabase RPC using func_claim_task."""
- if not SUPABASE_CLIENT:
- print("[ERROR] Supabase client not initialized. Cannot get task.")
- return None
- try:
- worker_id = f"worker_{os.getpid()}" # Example worker ID
- dprint(f"DEBUG get_oldest_queued_task_supabase: About to call RPC func_claim_task.")
- dprint(f"DEBUG get_oldest_queued_task_supabase: PG_TABLE_NAME = '{PG_TABLE_NAME}' (type: {type(PG_TABLE_NAME)})")
- dprint(f"DEBUG get_oldest_queued_task_supabase: worker_id = '{worker_id}' (type: {type(worker_id)})")
-
- response = SUPABASE_CLIENT.rpc(
- "func_claim_task",
- {"p_table_name": PG_TABLE_NAME, "p_worker_id": worker_id}
- ).execute()
-
- dprint(f"Supabase RPC func_claim_task response data: {response.data}")
-
- if response.data and len(response.data) > 0:
- task_data = response.data[0] # RPC should return a single row or empty
- # Ensure the RPC returns task_id_out, params_out, and now task_type_out
- if task_data.get("task_id_out") and task_data.get("params_out") is not None and task_data.get("task_type_out") is not None:
- dprint(f"Supabase RPC: Claimed task {task_data['task_id_out']} of type {task_data['task_type_out']}")
- return {"task_id": task_data["task_id_out"], "params": task_data["params_out"], "task_type": task_data["task_type_out"]}
- else:
- dprint("Supabase RPC: func_claim_task returned but no task was claimed or required fields (task_id_out, params_out, task_type_out) are missing.")
- return None
- else:
- dprint("Supabase RPC: No task claimed or empty response from func_claim_task.")
- return None
- except Exception as e:
- print(f"[ERROR] Supabase RPC func_claim_task failed: {e}")
- traceback.print_exc()
- return None
-def update_task_status_supabase(task_id_str, status_str, output_location_val=None): # Renamed from update_task_status_postgres
- """Updates a task's status via Supabase RPC using func_update_task_status."""
- if not SUPABASE_CLIENT:
- print("[ERROR] Supabase client not initialized. Cannot update task status.")
- return
- try:
- params = {
- "p_table_name": PG_TABLE_NAME,
- "p_task_id": task_id_str,
- "p_status": status_str
- }
- if output_location_val is not None:
- params["p_output_location"] = output_location_val
-
- SUPABASE_CLIENT.rpc("func_update_task_status", params).execute()
- dprint(f"Supabase RPC: Updated status of task {task_id_str} to {status_str}. Output: {output_location_val if output_location_val else 'N/A'}")
- except Exception as e:
- print(f"[ERROR] Supabase RPC func_update_task_status for {task_id_str} failed: {e}")
-
-def _migrate_supabase_schema():
- """Applies necessary schema migrations to an existing Supabase/PostgreSQL database via RPC."""
- if not SUPABASE_CLIENT:
- print("[ERROR] Supabase Migration: Supabase client not initialized. Cannot run migration.")
- return
-
- dprint(f"Supabase Migration: Requesting schema migration via RPC 'func_migrate_tasks_for_task_type' for table {PG_TABLE_NAME}...")
- try:
- # IMPORTANT: The user must create the SQL function 'func_migrate_tasks_for_task_type'
- # in their Supabase/PostgreSQL database.
- # This function should take p_table_name as TEXT argument and:
- # 1. Check if the 'task_type' column exists in the p_table_name.
- # 2. If not, execute: ALTER TABLE {p_table_name} ADD COLUMN task_type TEXT;
- # (Add as nullable first to allow adding to existing tables with data).
- # 3. Populate the new 'task_type' column for existing rows where it's NULL,
- # by extracting the type from the 'params' JSONB column. Example:
- # UPDATE {p_table_name}
- # SET task_type = params->>'task_type' -- Or other logic if task type was stored differently
- # WHERE task_type IS NULL AND params IS NOT NULL AND params->>'task_type' IS NOT NULL;
- # (Adjust the params->>'task_type' according to how task type was previously stored in params).
- # 4. The func_initialize_tasks_table RPC should already ensure task_type is NOT NULL for new tables.
-
- response = SUPABASE_CLIENT.rpc("func_migrate_tasks_for_task_type", {"p_table_name": PG_TABLE_NAME}).execute()
-
- # Improved response handling based on Supabase Python client v2+ structure
- if response.error:
- print(f"[ERROR] Supabase Migration: RPC 'func_migrate_tasks_for_task_type' returned an error: {response.error.message} (Code: {response.error.code}, Details: {response.error.details})")
- elif response.data:
- dprint(f"Supabase Migration: RPC 'func_migrate_tasks_for_task_type' executed. Response data: {response.data}")
- else:
- dprint("Supabase Migration: RPC 'func_migrate_tasks_for_task_type' executed. (No specific data or error in response, check RPC logs if issues)")
-
- except Exception as e:
- print(f"[ERROR] Supabase Migration: Failed to execute RPC 'func_migrate_tasks_for_task_type': {e}")
- traceback.print_exc()
+def _wait_for_sqlite_change(db_path_str: str, timeout_seconds: int):
+ """Block up to timeout_seconds waiting for mtime change on db / -wal / -shm.
-def _run_db_migrations():
- """Runs database migrations based on the configured DB_TYPE."""
- dprint(f"DB Migrations: Running for DB_TYPE: {DB_TYPE}")
- if DB_TYPE == "sqlite":
- if SQLITE_DB_PATH:
- _migrate_sqlite_schema(SQLITE_DB_PATH)
- else:
- print("[ERROR] DB Migration: SQLITE_DB_PATH not set. Skipping SQLite migration.")
- elif DB_TYPE == "supabase":
- if SUPABASE_CLIENT and PG_TABLE_NAME:
- _migrate_supabase_schema()
- else:
- print("[ERROR] DB Migration: Supabase client or PG_TABLE_NAME not configured. Skipping Supabase migration.")
- else:
- dprint(f"DB Migrations: No migration logic for DB_TYPE '{DB_TYPE}'. Skipping migrations.")
-
-# Helper to query DB for a specific task's output (needed by segment handler)
-def get_task_output_location_from_db(task_id_to_find: str) -> str | None:
- dprint(f"Querying DB for output location of task: {task_id_to_find}")
- if DB_TYPE == "sqlite":
- def _get_op(conn):
- cursor = conn.cursor()
- # Ensure we only get tasks that are actually complete with an output
- cursor.execute("SELECT output_location FROM tasks WHERE task_id = ? AND status = ? AND output_location IS NOT NULL",
- (task_id_to_find, STATUS_COMPLETE))
- row = cursor.fetchone()
- return row[0] if row else None
- try:
- return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
- except Exception as e:
- print(f"Error querying SQLite for task output {task_id_to_find}: {e}")
- return None
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
- try:
- # This assumes an RPC function `func_get_task_details` exists or similar direct query access.
- # The RPC should return at least `output_location` and `status`.
- # Example: response = SUPABASE_CLIENT.rpc("func_get_task_details", {"p_task_id": task_id_to_find}).execute()
- # If direct query:
- response = SUPABASE_CLIENT.table(PG_TABLE_NAME)\
- .select("output_location, status")\
- .eq("task_id", task_id_to_find)\
- .execute()
-
- if response.data and len(response.data) > 0:
- task_details = response.data[0]
- if task_details.get("status") == STATUS_COMPLETE and task_details.get("output_location"):
- return task_details.get("output_location")
- else:
- dprint(f"Task {task_id_to_find} found but not complete or no output_location. Status: {task_details.get('status')}")
- return None
- else:
- dprint(f"Task {task_id_to_find} not found in Supabase.")
- return None
- except Exception as e:
- print(f"Error querying Supabase for task output {task_id_to_find}: {e}")
- traceback.print_exc()
- return None
- dprint(f"DB type {DB_TYPE} not supported or client not init for get_task_output_location_from_db")
- return None
-
-# Function to find the task_id of a previous segment by its index and run_id
-# This is crucial for _handle_travel_segment_task to find its dependency.
-def get_previous_segment_task_id(orchestrator_run_id: str, target_segment_index: int) -> str | None:
- dprint(f"Querying DB for task_id of segment_index {target_segment_index} for run_id {orchestrator_run_id}")
- if DB_TYPE == "sqlite":
- def _get_op(conn):
- cursor = conn.cursor()
- # Assumes segment_task_id is stored in `params` or is the main `task_id`
- # And that `segment_index` and `orchestrator_run_id` are in `params`
- query = f"SELECT task_id FROM tasks WHERE json_extract(params, '$.orchestrator_run_id') = ? AND json_extract(params, '$.segment_index') = ? AND task_type = 'travel_segment' AND status = ?"
- cursor.execute(query, (orchestrator_run_id, target_segment_index, STATUS_COMPLETE))
- row = cursor.fetchone()
- return row[0] if row else None
- try:
- return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
- except Exception as e:
- print(f"Error querying SQLite for previous segment task_id: {e}")
- return None
- elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
+ This lets the headless server react almost immediately when another process
+ commits a transaction that only touches the WAL file. We fall back to a
+ normal sleep when the auxiliary files do not exist (e.g. before first
+ write).
+ """
+ related_paths = [
+ Path(db_path_str),
+ Path(f"{db_path_str}-wal"),
+ Path(f"{db_path_str}-shm"),
+ ]
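+    # In WAL mode a commit may touch only the -wal/-shm sidecars and leave the main
+    # .db file's mtime unchanged, which is why all three paths are watched here.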
+
+ # Snapshot the most-recent modification time we can observe now.
+ last_mtime = 0.0
+ for p in related_paths:
try:
- response = SUPABASE_CLIENT.table(PG_TABLE_NAME)\
- .select("task_id")\
- .eq("params->>orchestrator_run_id", orchestrator_run_id)\
- .eq("params->>segment_index", target_segment_index)\
- .eq("task_type", "travel_segment")\
- .eq("status", STATUS_COMPLETE)\
- .limit(1)\
- .execute()
- if response.data and len(response.data) > 0:
- return response.data[0]["task_id"]
- return None
- except Exception as e:
- print(f"Error querying Supabase for previous segment task_id: {e}")
- return None
- return None
+ last_mtime = max(last_mtime, p.stat().st_mtime)
+ except FileNotFoundError:
+ # Aux file not created yet – ignore.
+ pass
+
+ # Poll in small increments until something changes or timeout expires.
+ poll_step = 0.25 # seconds
+ waited = 0.0
+ while waited < timeout_seconds:
+ time.sleep(poll_step)
+ waited += poll_step
+ for p in related_paths:
+ try:
+ if p.stat().st_mtime > last_mtime:
+ return # Change detected – return immediately
+ except FileNotFoundError:
+ # File still missing – keep waiting
+ pass
+ # Timed out – return control to caller
+ return
if __name__ == "__main__":
main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 26a770994..d67f17857 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,66 @@
-# Headless server specific dependencies:
-python-dotenv # For .env file management
-supabase # For Supabase DB (via RPC) and Storage
-requests # For downloading LoRAs
-Pillow # For image processing in headless tasks
+# Combined requirements for Headless-Wan2GP project
+# Includes both headless server dependencies and Wan2GP model dependencies
+
+# ===== Core ML/AI Dependencies =====
+torch>=2.4.0
+torchvision>=0.19.0
+transformers==4.51.3
+tokenizers>=0.20.3
+diffusers>=0.31.0
+accelerate>=1.1.1
+peft>=0.15.0 # Updated from 0.14.0 to fix compatibility with diffusers
+huggingface_hub>=0.25.0 # For reliable model/LoRA downloads with checksums
+safetensors>=0.4.0 # For LoRA integrity verification
+timm
+einops
+numpy>=1.23.5,<2
+
+# ===== Computer Vision Dependencies =====
+opencv-python>=4.9.0.80
+opencv-python-headless # For headless environments (note: this and opencv-python above ship the same cv2 module and can conflict; keep only one)
+segment-anything
+rembg[gpu]==2.0.65
+
+# ===== Media Processing =====
+imageio
+imageio-ffmpeg
+moviepy==1.0.3
+av
+librosa
+mutagen
+decord
+Pillow # For image processing in headless tasks
mediapipe
-requests # For downloading videos from URLs in travel_between_images
\ No newline at end of file
+
+# ===== Web UI and API =====
+gradio==5.23.0
+fastapi # Often needed for gradio apps
+python-multipart # For file uploads
+
+# ===== Database and Storage =====
+supabase # For Supabase DB (via RPC) and Storage
+python-dotenv # For .env file management
+
+# ===== Utilities =====
+requests # For downloading LoRAs and videos
+replicate # For magic edit tasks via Replicate API
+tqdm
+easydict
+ftfy
+dashscope
+omegaconf
+hydra-core
+loguru
+sentencepiece
+pydantic==2.10.6
+
+# ===== GPU/Performance Dependencies =====
+onnxruntime-gpu
+mmgp==3.4.8
+# flash_attn # Uncomment if needed for flash attention
+
+# ===== Visualization =====
+matplotlib
+
+# ===== Legacy/Compatibility =====
+dotenv # Legacy package name; redundant with python-dotenv above, remove once confirmed unused
\ No newline at end of file
diff --git a/samples/1.png b/samples/1.png
new file mode 100644
index 000000000..358ab0300
Binary files /dev/null and b/samples/1.png differ
diff --git a/samples/2.png b/samples/2.png
new file mode 100644
index 000000000..082cc01e3
Binary files /dev/null and b/samples/2.png differ
diff --git a/samples/3.png b/samples/3.png
new file mode 100644
index 000000000..66eb1a86a
Binary files /dev/null and b/samples/3.png differ
diff --git a/samples/image_1.png b/samples/image_1.png
new file mode 100644
index 000000000..d63af0997
Binary files /dev/null and b/samples/image_1.png differ
diff --git a/samples/image_2.png b/samples/image_2.png
new file mode 100644
index 000000000..84ddb4ab0
Binary files /dev/null and b/samples/image_2.png differ
diff --git a/samples/image_3.png b/samples/image_3.png
new file mode 100644
index 000000000..edfcb6838
Binary files /dev/null and b/samples/image_3.png differ
diff --git a/samples/pose.png b/samples/pose.png
new file mode 100644
index 000000000..a425ae270
Binary files /dev/null and b/samples/pose.png differ
diff --git a/sm_functions/common_utils.py b/sm_functions/common_utils.py
deleted file mode 100644
index 62c9f0ec1..000000000
--- a/sm_functions/common_utils.py
+++ /dev/null
@@ -1,1117 +0,0 @@
-"""Common utility functions and constants for steerable_motion tasks."""
-
-import json
-import os
-import shutil
-import sqlite3
-import subprocess
-import tempfile
-import time
-import traceback
-import uuid
-from pathlib import Path
-
-import cv2 # pip install opencv-python
-import mediapipe as mp # pip install mediapipe
-import numpy as np # pip install numpy
-from PIL import Image, ImageDraw, ImageFont # pip install Pillow
-
-# --- Global Debug Mode ---
-# This will be set by the main script (steerable_motion.py)
-DEBUG_MODE = False
-
-# --- Constants for DB interaction and defaults ---
-STATUS_QUEUED = "Queued"
-STATUS_IN_PROGRESS = "In Progress"
-STATUS_COMPLETE = "Complete"
-STATUS_FAILED = "Failed"
-DEFAULT_DB_TABLE_NAME = "tasks"
-# DEFAULT_MODEL_NAME = "vace_14B" # Defined in steerable_motion.py's argparser
-# DEFAULT_SEGMENT_FRAMES = 81 # Defined in steerable_motion.py's argparser
-# DEFAULT_FPS_HELPERS = 25 # Defined in steerable_motion.py's argparser
-# DEFAULT_SEED = 12345 # Defined in steerable_motion.py's argparser
-
-# --- Debug / Verbose Logging Helper ---
-def dprint(msg: str):
- """Print a debug message if DEBUG_MODE is enabled."""
- if DEBUG_MODE:
- print(f"[DEBUG SM-COMMON {time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")
-
-# --- Helper Functions ---
-
-def parse_resolution(res_str: str) -> tuple[int, int]:
- """Parses 'WIDTHxHEIGHT' string to (width, height) tuple."""
- try:
- w, h = map(int, res_str.split('x'))
- if w <= 0 or h <= 0:
- raise ValueError("Width and height must be positive.")
- return w, h
- except ValueError as e:
- raise ValueError(f"Resolution string must be in WIDTHxHEIGHT format with positive integers (e.g., '960x544'), got {res_str}. Error: {e}")
-
-def generate_unique_task_id(prefix="sm_task_") -> str:
- """Generates a unique task ID."""
- return f"{prefix}{uuid.uuid4().hex[:12]}"
-
-def image_to_frame(image_path: str | Path, target_size: tuple[int, int]) -> np.ndarray | None:
- """Loads an image, resizes it, and converts to BGR NumPy array for OpenCV."""
- try:
- img = Image.open(image_path).convert("RGB")
- img = img.resize(target_size, Image.Resampling.LANCZOS)
- return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
- except FileNotFoundError:
- print(f"Error: Image file not found at {image_path}")
- return None
- except Exception as e:
- print(f"Error loading or processing image {image_path}: {e}")
- return None
-
-def create_color_frame(size: tuple[int, int], color_bgr: tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
- """Creates a single color BGR frame (default black)."""
- height, width = size[1], size[0] # size is (width, height)
- frame = np.full((height, width, 3), color_bgr, dtype=np.uint8)
- return frame
-
-def create_video_from_frames_list(
- frames_list: list[np.ndarray],
- output_path: str | Path,
- fps: int,
- resolution: tuple[int, int] # width, height
-):
- """Creates an MP4 video from a list of NumPy BGR frames."""
- output_path = Path(output_path)
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- out = None
- try:
- out = cv2.VideoWriter(str(output_path), fourcc, float(fps), resolution)
- if not out.isOpened():
- raise IOError(f"Could not open video writer for {output_path}")
-
- for frame_np in frames_list:
- if frame_np.shape[1] != resolution[0] or frame_np.shape[0] != resolution[1]:
- frame_np_resized = cv2.resize(frame_np, resolution, interpolation=cv2.INTER_AREA)
- out.write(frame_np_resized)
- else:
- out.write(frame_np)
- print(f"Generated video: {output_path} ({len(frames_list)} frames)")
- finally:
- if out:
- out.release()
-
-def add_task_to_db(task_payload: dict, db_path: str | Path, task_type_str: str):
- conn = sqlite3.connect(str(db_path))
- cursor = conn.cursor()
- try:
- # Ensure task_type is not in the params dict as it's now a separate column
- # This is important if the caller still includes it in task_payload by old habit.
- # The task_payload itself is still stored in the 'params' column.
- # The task_id is expected to be in task_payload.
- # The task_id for the DB record should be consistent.
- # If task_payload (which becomes params) already has a task_id, use that.
- # Otherwise, a task_id needs to be generated/provided.
- # The original line was: task_payload["task_id"]
- # This means the task_id to be inserted was expected to be part of the task_payload dict.
-
- # Let's assume the task_id passed to this function (if any) or generated by it
- # is the PRIMARY KEY for the DB.
- # The `task_payload` argument to this function is what gets stored in the `params` column.
-
- current_task_id = task_payload.get("task_id") # This is the task_id from the original payload structure.
- if not current_task_id:
- # This case should ideally not happen if `steerable_motion.py` prepares `task_payload` correctly.
- # For safety, one might generate one, but it's better to ensure the caller provides it.
- # Sticking to the original structure where task_payload["task_id"] was used:
- raise ValueError("task_id must be present in task_payload for add_task_to_db")
-
- headless_params_dict = task_payload.copy() # Work on a copy
- if "task_type" in headless_params_dict:
- del headless_params_dict["task_type"] # Remove if it exists, to avoid redundancy
-
- # The task_id for the DB record should be consistent.
- # If task_payload (which becomes params) already has a task_id, use that.
- # Otherwise, a task_id needs to be generated/provided.
- # The original line was: task_payload["task_id"]
- # This means the task_id to be inserted was expected to be part of the task_payload dict.
-
- # Let's assume the task_id passed to this function (if any) or generated by it
- # is the PRIMARY KEY for the DB.
- # The `task_payload` argument to this function is what gets stored in the `params` column.
-
- params_json_for_db = json.dumps(headless_params_dict)
-
- cursor.execute(
- f"INSERT INTO {DEFAULT_DB_TABLE_NAME} (task_id, params, task_type, status) VALUES (?, ?, ?, ?)",
- (current_task_id, params_json_for_db, task_type_str, STATUS_QUEUED)
- )
- conn.commit()
- print(f"Task {current_task_id} (Type: {task_type_str}) added to database {db_path}.")
- except sqlite3.Error as e:
- # Use current_task_id in error message if available
- task_id_for_error = task_payload.get("task_id", "UNKNOWN_TASK_ID")
- print(f"SQLite error when adding task {task_id_for_error} (Type: {task_type_str}): {e}")
- raise
- finally:
- conn.close()
-
-def poll_task_status(task_id: str, db_path: str | Path, poll_interval_seconds: int = 10, timeout_seconds: int = 1800) -> str | None:
- """Polls the DB for task completion and returns the output_location."""
- print(f"Polling for completion of task {task_id} (timeout: {timeout_seconds}s)...")
- start_time = time.time()
- last_status_print_time = 0
-
- while True:
- current_time = time.time()
- if current_time - start_time > timeout_seconds:
- print(f"Error: Timeout polling for task {task_id} after {timeout_seconds} seconds.")
- return None
-
- conn = sqlite3.connect(str(db_path))
- conn.row_factory = sqlite3.Row
- cursor = conn.cursor()
- try:
- cursor.execute(f"SELECT status, output_location FROM {DEFAULT_DB_TABLE_NAME} WHERE task_id = ?", (task_id,))
- row = cursor.fetchone()
- except sqlite3.Error as e:
- print(f"SQLite error while polling task {task_id}: {e}. Retrying...")
- conn.close()
- time.sleep(min(poll_interval_seconds, 5)) # Shorter sleep on DB error
- continue
- finally:
- conn.close()
-
- if row:
- status = row["status"]
- output_location = row["output_location"]
-
- if current_time - last_status_print_time > poll_interval_seconds * 2 : # Print status periodically
- print(f"Task {task_id}: Status = {status} (Output: {output_location if output_location else 'N/A'})")
- last_status_print_time = current_time
-
- if status == STATUS_COMPLETE:
- if output_location:
- print(f"Task {task_id} completed successfully. Output: {output_location}")
- return output_location
- else:
- print(f"Error: Task {task_id} is COMPLETE but output_location is missing. Assuming failure.")
- return None
- elif status == STATUS_FAILED:
- print(f"Error: Task {task_id} failed.")
- return None
- elif status not in [STATUS_QUEUED, STATUS_IN_PROGRESS]:
- print(f"Warning: Task {task_id} has unknown status '{status}'. Treating as error.")
- return None
- else:
- if current_time - last_status_print_time > poll_interval_seconds * 2 :
- print(f"Task {task_id}: Not found in DB yet or status pending...")
- last_status_print_time = current_time
-
- time.sleep(poll_interval_seconds)
-
-def extract_video_segment_ffmpeg(
- input_video_path: str | Path,
- output_video_path: str | Path,
- start_frame_index: int, # 0-indexed
- num_frames_to_keep: int,
- input_fps: float, # FPS of the input video for accurate -ss calculation
- resolution: tuple[int,int]
-):
- """Extracts a video segment using FFmpeg with stream copy if possible."""
- dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Called with input='{input_video_path}', output='{output_video_path}', start_idx={start_frame_index}, num_frames={num_frames_to_keep}, input_fps={input_fps}")
- if num_frames_to_keep <= 0:
- print(f"Warning: num_frames_to_keep is {num_frames_to_keep} for {output_video_path} (FFmpeg). Nothing to extract.")
- dprint("EXTRACT_VIDEO_SEGMENT_FFMPEG: num_frames_to_keep is 0 or less, returning.")
- Path(output_video_path).unlink(missing_ok=True)
- return
-
- input_video_path = Path(input_video_path)
- output_video_path = Path(output_video_path)
- output_video_path.parent.mkdir(parents=True, exist_ok=True)
-
- start_time_seconds = start_frame_index / input_fps
-
- cmd = [
- 'ffmpeg',
- '-y',
- '-ss', str(start_time_seconds),
- '-i', str(input_video_path.resolve()),
- '-vframes', str(num_frames_to_keep),
- '-an',
- str(output_video_path.resolve())
- ]
-
- dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Running command: {' '.join(cmd)}")
- try:
- process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
- dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Successfully extracted segment to {output_video_path}")
- if process.stderr:
- dprint(f"FFmpeg stderr (for {output_video_path}):\n{process.stderr}")
- if not output_video_path.exists() or output_video_path.stat().st_size == 0:
- print(f"Error: FFmpeg command for {output_video_path} apparently succeeded but output file is missing or empty.")
- dprint(f"FFmpeg command for {output_video_path} produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
-
- except subprocess.CalledProcessError as e:
- print(f"Error during FFmpeg segment extraction for {output_video_path}:")
- print("FFmpeg command:", ' '.join(e.cmd))
- if e.stdout: print("FFmpeg stdout:\n", e.stdout)
- if e.stderr: print("FFmpeg stderr:\n", e.stderr)
- dprint(f"FFmpeg extraction failed for {output_video_path}. Error: {e}")
- Path(output_video_path).unlink(missing_ok=True)
- except FileNotFoundError:
- print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
- dprint("FFmpeg command not found during segment extraction.")
- raise
-
-def stitch_videos_ffmpeg(video_paths_list: list[str | Path], output_path: str | Path):
- output_path = Path(output_path)
- output_path.parent.mkdir(parents=True, exist_ok=True)
-
- if not video_paths_list:
- print("No videos to stitch.")
- return
-
- valid_video_paths = []
- for p in video_paths_list:
- resolved_p = Path(p).resolve()
- if resolved_p.exists() and resolved_p.stat().st_size > 0:
- valid_video_paths.append(resolved_p)
- else:
- print(f"Warning: Video segment {resolved_p} is missing or empty. Skipping from stitch list.")
-
- if not valid_video_paths:
- print("No valid video segments found to stitch after checks.")
- return
-
- with tempfile.TemporaryDirectory(prefix="ffmpeg_concat_") as tmpdir:
- filelist_path = Path(tmpdir) / "ffmpeg_filelist.txt"
- with open(filelist_path, 'w', encoding='utf-8') as f:
- for video_path in valid_video_paths:
- f.write(f"file '{video_path.as_posix()}'\n")
-
- cmd = [
- 'ffmpeg', '-y', '-f', 'concat', '-safe', '0',
- '-i', str(filelist_path),
- '-c', 'copy', str(output_path)
- ]
-
- print(f"Running ffmpeg to stitch videos: {' '.join(cmd)}")
- try:
- process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
- print(f"Successfully stitched videos into: {output_path}")
- if process.stderr: print("FFmpeg log (stderr):\n", process.stderr)
- except subprocess.CalledProcessError as e:
- print(f"Error during ffmpeg stitching for {output_path}:")
- print("FFmpeg command:", ' '.join(e.cmd))
- if e.stdout: print("FFmpeg stdout:\n", e.stdout)
- if e.stderr: print("FFmpeg stderr:\n", e.stderr)
- raise
- except FileNotFoundError:
- print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
- raise
-
-def save_frame_from_video(video_path: Path, frame_index: int, output_image_path: Path, resolution: tuple[int, int]):
- """Extracts a specific frame from a video, resizes, and saves it as an image."""
- dprint(f"SAVE_FRAME_FROM_VIDEO: Input='{video_path}', Index={frame_index}, Output='{output_image_path}', Res={resolution}")
- if not video_path.exists() or video_path.stat().st_size == 0:
- print(f"Error: Video file for frame extraction not found or empty: {video_path}")
- return False
-
- cap = cv2.VideoCapture(str(video_path))
- if not cap.isOpened():
- print(f"Error: Could not open video file: {video_path}")
- return False
-
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- if frame_index < 0 or frame_index >= total_frames:
- print(f"Error: Frame index {frame_index} is out of bounds for video {video_path} (total frames: {total_frames}).")
- cap.release()
- return False
-
- cap.set(cv2.CAP_PROP_POS_FRAMES, float(frame_index))
- ret, frame = cap.read()
- cap.release()
-
- if not ret or frame is None:
- print(f"Error: Could not read frame {frame_index} from {video_path}.")
- return False
-
- try:
- if frame.shape[1] != resolution[0] or frame.shape[0] != resolution[1]:
- dprint(f"SAVE_FRAME_FROM_VIDEO: Resizing frame from {frame.shape[:2]} to {resolution[:2][::-1]}")
- frame = cv2.resize(frame, resolution, interpolation=cv2.INTER_AREA)
-
- output_image_path.parent.mkdir(parents=True, exist_ok=True)
- cv2.imwrite(str(output_image_path), frame)
- print(f"Successfully saved frame {frame_index} from {video_path} to {output_image_path}")
- return True
- except Exception as e:
- print(f"Error saving frame to {output_image_path}: {e}")
- traceback.print_exc()
- return False
-
-# --- FFMPEG-based specific frame extraction ---
-def extract_specific_frame_ffmpeg(
- input_video_path: str | Path,
- frame_number: int, # 0-indexed
- output_image_path: str | Path,
- input_fps: float # Passed by caller, though not strictly needed for ffmpeg frame index selection using 'eq(n,frame_number)'
-):
- """Extracts a specific frame from a video using FFmpeg and saves it as an image."""
- dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Input='{input_video_path}', Frame={frame_number}, Output='{output_image_path}'")
- input_video_p = Path(input_video_path)
- output_image_p = Path(output_image_path)
- output_image_p.parent.mkdir(parents=True, exist_ok=True)
-
- if not input_video_p.exists() or input_video_p.stat().st_size == 0:
- print(f"Error: Input video for frame extraction not found or empty: {input_video_p}")
- dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Input video {input_video_p} not found or empty. Returning False.")
- return False
-
- cmd = [
- 'ffmpeg',
- '-y', # Overwrite output without asking
- '-i', str(input_video_p.resolve()),
- '-vf', f"select=eq(n\,{frame_number})", # Escaped comma for ffmpeg filter syntax
- '-vframes', '1',
- str(output_image_p.resolve())
- ]
-
- dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Running command: {' '.join(cmd)}")
- try:
- process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
- dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Successfully extracted frame {frame_number} to {output_image_p}")
- if process.stderr:
- dprint(f"FFmpeg stderr (for frame extraction to {output_image_p}):\n{process.stderr}")
- if not output_image_p.exists() or output_image_p.stat().st_size == 0:
- print(f"Error: FFmpeg command for frame extraction to {output_image_p} apparently succeeded but output file is missing or empty.")
- dprint(f"FFmpeg command for {output_image_p} (frame extraction) produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
- return False
- return True
- except subprocess.CalledProcessError as e:
- print(f"Error during FFmpeg frame extraction for {output_image_p}:")
- print("FFmpeg command:", ' '.join(e.cmd))
- if e.stdout: print("FFmpeg stdout:\n", e.stdout)
- if e.stderr: print("FFmpeg stderr:\n", e.stderr)
- dprint(f"FFmpeg frame extraction failed for {output_image_p}. Error: {e}")
- if output_image_p.exists(): output_image_p.unlink(missing_ok=True)
- return False
- except FileNotFoundError:
- print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
- dprint("FFmpeg command not found during frame extraction.")
- raise
-
-# --- FFMPEG-based video concatenation (alternative to stitch_videos_ffmpeg if caller manages temp dir) ---
-def concatenate_videos_ffmpeg(
- video_paths: list[str | Path],
- output_path: str | Path,
- temp_dir_for_list: str | Path # Directory where the list file will be created
-):
- """Concatenates multiple video files into one using FFmpeg, using a provided temp directory for the list file."""
- output_p = Path(output_path)
- output_p.parent.mkdir(parents=True, exist_ok=True)
- temp_dir_p = Path(temp_dir_for_list)
- temp_dir_p.mkdir(parents=True, exist_ok=True)
-
- if not video_paths:
- print("No videos to concatenate.")
- dprint("CONCATENATE_VIDEOS_FFMPEG: No video paths provided. Returning.")
- if output_p.exists(): output_p.unlink(missing_ok=True)
- return
-
- valid_video_paths = []
- for p_item in video_paths:
- resolved_p_item = Path(p_item).resolve()
- if resolved_p_item.exists() and resolved_p_item.stat().st_size > 0:
- valid_video_paths.append(resolved_p_item)
- else:
- print(f"Warning: Video segment {resolved_p_item} for concatenation is missing or empty. Skipping.")
- dprint(f"CONCATENATE_VIDEOS_FFMPEG: Skipping invalid video segment {resolved_p_item}")
-
- if not valid_video_paths:
- print("No valid video segments found to concatenate after checks.")
- dprint("CONCATENATE_VIDEOS_FFMPEG: No valid video segments. Returning.")
- if output_p.exists(): output_p.unlink(missing_ok=True)
- return
-
- filelist_path = temp_dir_p / "ffmpeg_concat_filelist.txt"
- with open(filelist_path, 'w', encoding='utf-8') as f:
- for video_path_item in valid_video_paths:
- f.write(f"file '{video_path_item.as_posix()}'\n") # Use as_posix() for ffmpeg list file
-
- cmd = [
- 'ffmpeg', '-y',
- '-f', 'concat',
- '-safe', '0',
- '-i', str(filelist_path.resolve()),
- '-c', 'copy',
- str(output_p.resolve())
- ]
-
- dprint(f"CONCATENATE_VIDEOS_FFMPEG: Running command: {' '.join(cmd)} with list file {filelist_path}")
- try:
- process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
- print(f"Successfully concatenated videos into: {output_p}")
- dprint(f"CONCATENATE_VIDEOS_FFMPEG: Success. Output: {output_p}")
- if process.stderr:
- dprint(f"FFmpeg stderr (for concatenation to {output_p}):\n{process.stderr}")
- if not output_p.exists() or output_p.stat().st_size == 0:
- print(f"Warning: FFmpeg concatenation to {output_p} apparently succeeded but output file is missing or empty.")
- dprint(f"FFmpeg command for {output_p} (concatenation) produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
- except subprocess.CalledProcessError as e:
- print(f"Error during FFmpeg concatenation for {output_p}:")
- print("FFmpeg command:", ' '.join(e.cmd))
- if e.stdout: print("FFmpeg stdout:\n", e.stdout)
- if e.stderr: print("FFmpeg stderr:\n", e.stderr)
- dprint(f"FFmpeg concatenation failed for {output_p}. Error: {e}")
- if output_p.exists(): output_p.unlink(missing_ok=True)
- raise
- except FileNotFoundError:
- print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
- dprint("CONCATENATE_VIDEOS_FFMPEG: ffmpeg command not found.")
- raise
-
-# --- OpenCV-based video properties extraction ---
-def get_video_frame_count_and_fps(video_path: str | Path) -> tuple[int | None, float | None]:
- """Gets frame count and FPS of a video using OpenCV. Returns (None, None) on failure."""
- video_path_str = str(Path(video_path).resolve())
- cap = None
- try:
- cap = cv2.VideoCapture(video_path_str)
- if not cap.isOpened():
- dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Could not open video: {video_path_str}")
- return None, None
-
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- fps = cap.get(cv2.CAP_PROP_FPS)
-
- # Validate frame_count and fps as they can sometimes be 0 or negative for problematic files/streams
- valid_frame_count = frame_count if frame_count > 0 else None
- valid_fps = fps if fps > 0 else None
-
- if valid_frame_count is None:
- dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} reported non-positive frame count: {frame_count}. Treating as unknown.")
- if valid_fps is None:
- dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} reported non-positive FPS: {fps}. Treating as unknown.")
-
- dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} - Frames: {valid_frame_count}, FPS: {valid_fps}")
- return valid_frame_count, valid_fps
- except Exception as e:
- dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Exception processing {video_path_str}: {e}")
- return None, None
- finally:
- if cap:
- cap.release()
-
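A quick usage sketch for the property probe above; both values should be None-checked before deriving a duration (the path is illustrative):

```python
from sm_functions.common_utils import get_video_frame_count_and_fps

frames, fps = get_video_frame_count_and_fps("outputs/segment_000.mp4")
if frames is None or fps is None:
    print("Could not determine frame count or FPS; using fallback defaults.")
else:
    print(f"{frames} frames at {fps:.2f} FPS -> {frames / fps:.2f} s")
```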
-# --- Pose Processing Constants and Helpers (from ComfyUI example) ---
-body_colors = [
- [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
- [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
- [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
-]
-face_color = [255, 255, 255]
-hand_keypoint_color = [0, 0, 255]
-hand_limb_colors = [
- [255,0,0],[255,60,0],[255,120,0],[255,180,0], [180,255,0],[120,255,0],[60,255,0],[0,255,0],
- [0,255,60],[0,255,120],[0,255,180],[0,180,255], [0,120,255],[0,60,255],[0,0,255],[60,0,255],
- [120,0,255],[180,0,255],[255,0,180],[255,0,120]
-]
-
-# MediaPipe Pose connections (33 landmarks, indices 0-32)
-# Based on mp.solutions.pose.POSE_CONNECTIONS
-body_skeleton = [
- (0, 1), (1, 2), (2, 3), (3, 7), # Nose to Left Eye to Left Ear
- (0, 4), (4, 5), (5, 6), (6, 8), # Nose to Right Eye to Right Ear
- (9, 10), # Mouth
- (11, 12), # Shoulders
- (11, 13), (13, 15), (15, 17), (17, 19), (15, 19), (15, 21), # Left Arm and simplified Left Hand (wrist to fingers)
- (12, 14), (14, 16), (16, 18), (18, 20), (16, 20), (16, 22), # Right Arm and simplified Right Hand (wrist to fingers)
- (11, 23), (12, 24), # Connect shoulders to Hips
- (23, 24), # Hip connection
- (23, 25), (25, 27), (27, 29), (29, 31), (27, 31), # Left Leg and Foot
- (24, 26), (26, 28), (28, 30), (30, 32), (28, 32) # Right Leg and Foot
-]
-
-face_skeleton = [] # Draw face dots only, no connections
-
-# MediaPipe Hand connections (21 landmarks per hand, indices 0-20)
-hand_skeleton = [
- (0, 1), (1, 2), (2, 3), (3, 4), # Thumb
- (0, 5), (5, 6), (6, 7), (7, 8), # Index finger
- (0, 9), (9, 10), (10, 11), (11, 12), # Middle finger
- (0, 13), (13, 14), (14, 15), (15, 16), # Ring finger
- (0, 17), (17, 18), (18, 19), (19, 20) # Pinky finger
-]
-
-def draw_keypoints_and_skeleton(image, keypoints_data, skeleton_connections, colors_config, confidence_threshold=0.1, point_radius=3, line_thickness=2, is_face=False, is_hand=False):
- if not keypoints_data:
- return
-
- tri_tuples = []
- if isinstance(keypoints_data, list) and len(keypoints_data) > 0 and isinstance(keypoints_data[0], (int, float)) and len(keypoints_data) % 3 == 0:
- for i in range(0, len(keypoints_data), 3):
- tri_tuples.append(keypoints_data[i:i+3])
- else:
- dprint(f"draw_keypoints_and_skeleton: Unexpected keypoints_data format or length not divisible by 3. Data: {keypoints_data}")
- return
-
- if skeleton_connections:
- for i, (joint_idx_a, joint_idx_b) in enumerate(skeleton_connections):
- if joint_idx_a >= len(tri_tuples) or joint_idx_b >= len(tri_tuples):
- continue
- a_x, a_y, a_confidence = tri_tuples[joint_idx_a]
- b_x, b_y, b_confidence = tri_tuples[joint_idx_b]
-
- if a_confidence >= confidence_threshold and b_confidence >= confidence_threshold:
- limb_color = None
- if is_hand:
- limb_color_list = colors_config['limbs']
- limb_color = limb_color_list[i % len(limb_color_list)]
- else:
- limb_color_list = colors_config if isinstance(colors_config, list) else [colors_config]
- limb_color = limb_color_list[i % len(limb_color_list)]
- if limb_color is not None:
- cv2.line(image, (int(a_x), int(a_y)), (int(b_x), int(b_y)), limb_color, line_thickness)
-
- for i, (x, y, confidence) in enumerate(tri_tuples):
- if confidence >= confidence_threshold:
- point_color = None
- current_radius = point_radius
- if is_hand:
- point_color = colors_config['points']
- elif is_face:
- point_color = colors_config
- current_radius = 2
- else:
- point_color_list = colors_config
- point_color = point_color_list[i % len(point_color_list)]
- if point_color is not None:
- cv2.circle(image, (int(x), int(y)), current_radius, point_color, -1)
-
-def gen_skeleton_with_face_hands(pose_keypoints_2d, face_keypoints_2d, hand_left_keypoints_2d, hand_right_keypoints_2d,
- canvas_width, canvas_height, landmarkType, confidence_threshold=0.1):
- image = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
-
- def scale_keypoints(keypoints, target_w, target_h, input_is_normalized):
- if not keypoints: return []
- scaled = []
- if not isinstance(keypoints, list) or (keypoints and not isinstance(keypoints[0], (int, float))):
- dprint(f"scale_keypoints: Unexpected keypoints format: {type(keypoints)}. Expecting flat list of numbers.")
- return []
-
- for i in range(0, len(keypoints), 3):
- x, y, conf = keypoints[i:i+3]
- if input_is_normalized: scaled.extend([x * target_w, y * target_h, conf])
- else: scaled.extend([x, y, conf])
- return scaled
-
- input_is_normalized = (landmarkType == "OpenPose") # This might need adjustment based on actual landmarkType usage
-
- scaled_pose = scale_keypoints(pose_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
- scaled_face = scale_keypoints(face_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
- scaled_hand_left = scale_keypoints(hand_left_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
- scaled_hand_right = scale_keypoints(hand_right_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
-
- draw_keypoints_and_skeleton(image, scaled_pose, body_skeleton, body_colors, confidence_threshold, point_radius=6, line_thickness=4)
- if scaled_face:
- draw_keypoints_and_skeleton(image, scaled_face, face_skeleton, face_color, confidence_threshold, point_radius=2, line_thickness=1, is_face=True)
- hand_colors_config = {'limbs': hand_limb_colors, 'points': hand_keypoint_color}
- if scaled_hand_left:
- draw_keypoints_and_skeleton(image, scaled_hand_left, hand_skeleton, hand_colors_config, confidence_threshold, point_radius=3, line_thickness=2, is_hand=True)
- if scaled_hand_right:
- draw_keypoints_and_skeleton(image, scaled_hand_right, hand_skeleton, hand_colors_config, confidence_threshold, point_radius=3, line_thickness=2, is_hand=True)
- return image
-
-def transform_all_keypoints(keypoints_1_dict, keypoints_2_dict, frames, interpolation="linear"):
- def interpolate_keypoint_set(kp1_list, kp2_list, num_frames, interp_method):
- if not kp1_list and not kp2_list: return [[] for _ in range(num_frames)]
-
- len1 = len(kp1_list) if kp1_list else 0
- len2 = len(kp2_list) if kp2_list else 0
-
- if not kp1_list: kp1_list = [0.0] * len2
- if not kp2_list: kp2_list = [0.0] * len1
-
- if len(kp1_list) != len(kp2_list) or not kp1_list or len(kp1_list) % 3 != 0:
- dprint(f"interpolate_keypoint_set: Mismatched, empty, or non-triplet keypoint lists after padding. KP1 len: {len(kp1_list)}, KP2 len: {len(kp2_list)}. Returning empty sequences.")
- return [[] for _ in range(num_frames)]
-
- tri_tuples_1 = [kp1_list[i:i + 3] for i in range(0, len(kp1_list), 3)]
- tri_tuples_2 = [kp2_list[i:i + 3] for i in range(0, len(kp2_list), 3)]
-
- keypoints_sequence = []
- for j in range(num_frames):
- interpolated_kps_for_frame = []
- t = j / float(num_frames - 1) if num_frames > 1 else 0.0
-
- interp_factor = t
- if interp_method == "ease-in": interp_factor = t * t
- elif interp_method == "ease-out": interp_factor = 1 - (1 - t) * (1 - t)
- elif interp_method == "ease-in-out":
- if t < 0.5: interp_factor = 2 * t * t
- else: interp_factor = 1 - pow(-2 * t + 2, 2) / 2
-
- for i in range(len(tri_tuples_1)):
- x1, y1, c1 = tri_tuples_1[i]
- x2, y2, c2 = tri_tuples_2[i]
- new_x, new_y, new_c = 0.0, 0.0, 0.0
-
- if c1 > 0.05 and c2 > 0.05:
- new_x = x1 + (x2 - x1) * interp_factor
- new_y = y1 + (y2 - y1) * interp_factor
- new_c = c1 + (c2 - c1) * interp_factor
- elif c1 > 0.05 and c2 <= 0.05:
- new_x, new_y = x1, y1
- new_c = c1 * (1.0 - interp_factor)
- elif c1 <= 0.05 and c2 > 0.05:
- new_x, new_y = x2, y2
- new_c = c2 * interp_factor
- interpolated_kps_for_frame.extend([new_x, new_y, new_c])
- keypoints_sequence.append(interpolated_kps_for_frame)
- return keypoints_sequence
-
- pose_1 = keypoints_1_dict.get('pose_keypoints_2d', [])
- face_1 = keypoints_1_dict.get('face_keypoints_2d', [])
- hand_left_1 = keypoints_1_dict.get('hand_left_keypoints_2d', [])
- hand_right_1 = keypoints_1_dict.get('hand_right_keypoints_2d', [])
-
- pose_2 = keypoints_2_dict.get('pose_keypoints_2d', [])
- face_2 = keypoints_2_dict.get('face_keypoints_2d', [])
- hand_left_2 = keypoints_2_dict.get('hand_left_keypoints_2d', [])
- hand_right_2 = keypoints_2_dict.get('hand_right_keypoints_2d', [])
-
- pose_sequence = interpolate_keypoint_set(pose_1, pose_2, frames, interpolation)
- face_sequence = interpolate_keypoint_set(face_1, face_2, frames, interpolation)
- hand_left_sequence = interpolate_keypoint_set(hand_left_1, hand_left_2, frames, interpolation)
- hand_right_sequence = interpolate_keypoint_set(hand_right_1, hand_right_2, frames, interpolation)
-
- combined_sequence = []
- for i in range(frames):
- combined_frame_data = {
- 'pose_keypoints_2d': pose_sequence[i] if i < len(pose_sequence) else [],
- 'face_keypoints_2d': face_sequence[i] if i < len(face_sequence) else [],
- 'hand_left_keypoints_2d': hand_left_sequence[i] if i < len(hand_left_sequence) else [],
- 'hand_right_keypoints_2d': hand_right_sequence[i] if i < len(hand_right_sequence) else []
- }
- combined_sequence.append(combined_frame_data)
- return combined_sequence
-
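A rough sketch of the interpolator's contract: two OpenPose-style keypoint dicts in, one dict per output frame back (the keypoint values below are made up):

```python
from sm_functions.common_utils import transform_all_keypoints

# Flat [x, y, confidence] triplets; only two joints each, purely for illustration
kp_start = {"pose_keypoints_2d": [100.0, 200.0, 0.9, 150.0, 250.0, 0.9]}
kp_end = {"pose_keypoints_2d": [300.0, 220.0, 0.9, 350.0, 260.0, 0.9]}

sequence = transform_all_keypoints(kp_start, kp_end, frames=25, interpolation="ease-in-out")
assert len(sequence) == 25  # one keypoint dict (pose/face/hands keys) per frame
```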
-def extract_pose_keypoints(image_path: str | Path, include_face=True, include_hands=True, resolution: tuple[int,int]=(640,480)) -> dict:
- # import mediapipe as mp # Already imported at top
- # import cv2 # Already imported at top
- image = cv2.imread(str(image_path))
- if image is None:
- raise ValueError(f"Failed to load image: {image_path}")
-
- # Resize image before processing if its resolution differs significantly,
- # or ensure MediaPipe processes at a consistent internal resolution.
- # For now, assume MediaPipe handles various input sizes well, and keypoints are normalized.
- # The passed 'resolution' param will be used to scale normalized keypoints back to absolute.
-
- height, width = resolution[1], resolution[0] # For scaling output coords
-
- mp_holistic = mp.solutions.holistic
- # It's good practice to use a try-finally for resources like MediaPipe Holistic
- holistic_instance = mp_holistic.Holistic(static_image_mode=True,
- min_detection_confidence=0.5,
- min_tracking_confidence=0.5)
- try:
- # Convert BGR image to RGB for MediaPipe
- results = holistic_instance.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
- finally:
- holistic_instance.close()
-
- keypoints = {}
- pose_kps = []
- if results.pose_landmarks:
- for lm in results.pose_landmarks.landmark:
- # lm.x, lm.y are normalized coordinates; scale them to target resolution
- pose_kps.extend([lm.x * width, lm.y * height, lm.visibility])
- keypoints['pose_keypoints_2d'] = pose_kps
-
- face_kps = []
- if include_face and results.face_landmarks:
- for lm in results.face_landmarks.landmark:
- face_kps.extend([lm.x * width, lm.y * height, lm.visibility if hasattr(lm, 'visibility') else 1.0]) # Some face landmarks might not have visibility
- keypoints['face_keypoints_2d'] = face_kps
-
- left_hand_kps = []
- if include_hands and results.left_hand_landmarks:
- for lm in results.left_hand_landmarks.landmark:
- left_hand_kps.extend([lm.x * width, lm.y * height, lm.visibility])
- keypoints['hand_left_keypoints_2d'] = left_hand_kps
-
- right_hand_kps = []
- if include_hands and results.right_hand_landmarks:
- for lm in results.right_hand_landmarks.landmark:
- right_hand_kps.extend([lm.x * width, lm.y * height, lm.visibility])
- keypoints['hand_right_keypoints_2d'] = right_hand_kps
-
- return keypoints
-
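Putting the two helpers together, a hedged sketch that extracts keypoints from a still image (requires `mediapipe` and `opencv-python`) and renders them as a single skeleton frame; the image path and size are placeholders:

```python
import cv2

from sm_functions.common_utils import extract_pose_keypoints, gen_skeleton_with_face_hands

res = (640, 480)  # (width, height)
kps = extract_pose_keypoints("anchor_frame.png", include_face=True, include_hands=True, resolution=res)

# Keypoints come back in absolute pixels, so pass a non-"OpenPose" landmark type (no rescaling)
skeleton = gen_skeleton_with_face_hands(
    kps.get("pose_keypoints_2d", []), kps.get("face_keypoints_2d", []),
    kps.get("hand_left_keypoints_2d", []), kps.get("hand_right_keypoints_2d", []),
    canvas_width=res[0], canvas_height=res[1], landmarkType="AbsoluteCoords",
)
cv2.imwrite("skeleton_preview.png", skeleton)
```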
-def create_pose_interpolated_guide_video(output_video_path: str | Path, resolution: tuple[int, int], total_frames: int,
- start_image_path: str | Path, end_image_path: str | Path,
- interpolation="linear", confidence_threshold=0.1,
- include_face=True, include_hands=True, fps=25):
- dprint(f"Creating pose interpolated guide: {output_video_path} from '{Path(start_image_path).name}' to '{Path(end_image_path).name}' ({total_frames} frames). First frame will be actual start image.")
-
- if total_frames <= 0:
- dprint(f"Video creation skipped for {output_video_path} as total_frames is {total_frames}.")
- return
-
- frames_list = []
- canvas_width, canvas_height = resolution
-
- first_visual_frame_np = image_to_frame(start_image_path, resolution)
-    if first_visual_frame_np is None:
-        print(f"Error loading start image {start_image_path} for guide video frame 0. Using black frame.")
-        first_visual_frame_np = create_color_frame(resolution, (0,0,0))
- frames_list.append(first_visual_frame_np)
-
- if total_frames > 1:
- try:
- # Pass the target resolution for keypoint scaling
- keypoints_from = extract_pose_keypoints(start_image_path, include_face, include_hands, resolution)
- keypoints_to = extract_pose_keypoints(end_image_path, include_face, include_hands, resolution)
- except Exception as e_extract:
- print(f"Error extracting keypoints for pose interpolation: {e_extract}. Filling remaining guide frames with black.")
- traceback.print_exc()
- black_frame = create_color_frame(resolution, (0,0,0))
- for _ in range(total_frames - 1):
- frames_list.append(black_frame)
- create_video_from_frames_list(frames_list, output_video_path, fps, resolution)
- return
-
- interpolated_sequence = transform_all_keypoints(keypoints_from, keypoints_to, total_frames, interpolation)
-
- # landmarkType for gen_skeleton_with_face_hands should indicate absolute coordinates
- # as extract_pose_keypoints now returns absolute coordinates scaled to 'resolution'
- landmark_type_for_gen = "AbsoluteCoords"
-
- for i in range(1, total_frames):
- if i < len(interpolated_sequence):
- frame_data = interpolated_sequence[i]
- pose_kps = frame_data.get('pose_keypoints_2d', [])
- face_kps = frame_data.get('face_keypoints_2d', []) if include_face else []
- hand_left_kps = frame_data.get('hand_left_keypoints_2d', []) if include_hands else []
- hand_right_kps = frame_data.get('hand_right_keypoints_2d', []) if include_hands else []
-
- img = gen_skeleton_with_face_hands(
- pose_kps, face_kps, hand_left_kps, hand_right_kps,
- canvas_width, canvas_height,
- landmark_type_for_gen, # Keypoints are already absolute
- confidence_threshold
- )
- frames_list.append(img)
- else:
- dprint(f"Warning: Interpolated sequence too short at index {i} for {output_video_path}. Appending black frame.")
- frames_list.append(create_color_frame(resolution, (0,0,0)))
-
- if len(frames_list) != total_frames:
- dprint(f"Warning: Generated {len(frames_list)} frames for {output_video_path}, expected {total_frames}. Adjusting.")
- if len(frames_list) < total_frames:
- last_frame = frames_list[-1] if frames_list else create_color_frame(resolution, (0,0,0))
- frames_list.extend([last_frame.copy() for _ in range(total_frames - len(frames_list))])
- else:
- frames_list = frames_list[:total_frames]
-
- if not frames_list:
- dprint(f"Error: No frames for video {output_video_path}. Skipping creation.")
- return
-
- create_video_from_frames_list(frames_list, output_video_path, fps, resolution)
-
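For context, a minimal call sketch for the guide-video builder above; frame 0 is the literal start image and the remaining frames are interpolated skeleton renders (paths and frame counts are illustrative):

```python
from sm_functions.common_utils import create_pose_interpolated_guide_video

create_pose_interpolated_guide_video(
    output_video_path="outputs/pose_guide.mp4",
    resolution=(640, 480),
    total_frames=49,
    start_image_path="anchor_start.png",
    end_image_path="anchor_end.png",
    interpolation="ease-in-out",
    fps=25,
)
```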
-# --- Debug Summary Video Helpers ---
-def get_resized_frame(video_path_str: str, target_size: tuple[int, int], frame_ratio: float = 0.5) -> np.ndarray | None:
- """Extracts a frame (by ratio, e.g., 0.5 for middle) from a video and resizes it."""
- video_path = Path(video_path_str)
- if not video_path.exists() or video_path.stat().st_size == 0:
- dprint(f"GET_RESIZED_FRAME: Video not found or empty: {video_path_str}")
- placeholder = create_color_frame(target_size, (10, 10, 10)) # Dark grey
- cv2.putText(placeholder, "Not Found", (10, target_size[1] // 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
- return placeholder
-
- cap = None
- try:
- cap = cv2.VideoCapture(str(video_path))
- if not cap.isOpened():
- dprint(f"GET_RESIZED_FRAME: Could not open video: {video_path_str}")
- return create_color_frame(target_size, (20,20,20))
-
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- if total_frames == 0:
- dprint(f"GET_RESIZED_FRAME: Video has 0 frames: {video_path_str}")
- return create_color_frame(target_size, (30,30,30))
-
- frame_to_get = int(total_frames * frame_ratio)
- frame_to_get = max(0, min(frame_to_get, total_frames - 1)) # Clamp
-
- cap.set(cv2.CAP_PROP_POS_FRAMES, float(frame_to_get))
- ret, frame = cap.read()
- if not ret or frame is None:
- dprint(f"GET_RESIZED_FRAME: Could not read frame {frame_to_get} from: {video_path_str}")
- return create_color_frame(target_size, (40,40,40))
-
- return cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
- except Exception as e:
- dprint(f"GET_RESIZED_FRAME: Exception processing {video_path_str}: {e}")
- return create_color_frame(target_size, (50,50,50)) # Error color
- finally:
- if cap: cap.release()
-
-def draw_multiline_text(image, text_lines, start_pos, font, font_scale, color, thickness, line_spacing):
- x, y = start_pos
- for i, line in enumerate(text_lines):
- line_y = y + (i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing))
- cv2.putText(image, line, (x, line_y), font, font_scale, color, thickness, cv2.LINE_AA)
- return image
-
-def generate_debug_summary_video(segments_data: list[dict], output_path: str | Path, fps: int,
- num_frames_for_collage: int,
- target_thumb_size: tuple[int, int] = (320, 180)):
- if not DEBUG_MODE: return # Only run if debug mode is on
- if not segments_data:
- dprint("GENERATE_DEBUG_SUMMARY_VIDEO: No segment data provided.")
- return
-
- dprint(f"Generating animated debug collage with {num_frames_for_collage} frames, at {fps} FPS.")
-
- thumb_w, thumb_h = target_thumb_size
- padding = 10
- header_h = 50
- text_line_h_approx = 20
- max_settings_lines = 6
- settings_area_h = (text_line_h_approx * max_settings_lines) + padding
-
- num_segments = len(segments_data)
- col_w = thumb_w + (2 * padding)
- canvas_w = num_segments * col_w
- canvas_h = header_h + (thumb_h * 2) + (padding * 3) + settings_area_h + padding
-
- font = cv2.FONT_HERSHEY_SIMPLEX
- font_scale_small = 0.4
- font_scale_title = 0.6
- text_color = (230, 230, 230)
- title_color = (255, 255, 255)
- line_thickness = 1
-
- overall_static_template_canvas = np.full((canvas_h, canvas_w, 3), (30, 30, 30), dtype=np.uint8)
- for idx, seg_data in enumerate(segments_data):
- col_x_start = idx * col_w
- center_x_col = col_x_start + col_w // 2
- title_text = f"Segment {seg_data['segment_index']}"
- (tw, th), _ = cv2.getTextSize(title_text, font, font_scale_title, line_thickness)
- cv2.putText(overall_static_template_canvas, title_text, (center_x_col - tw//2, header_h - padding), font, font_scale_title, title_color, line_thickness, cv2.LINE_AA)
-
- y_offset = header_h
- cv2.putText(overall_static_template_canvas, "Input Guide", (col_x_start + padding, y_offset + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
- y_offset += thumb_h + padding
- cv2.putText(overall_static_template_canvas, "Headless Output", (col_x_start + padding, y_offset + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
- y_offset += thumb_h + padding
-
- settings_y_start = y_offset
- cv2.putText(overall_static_template_canvas, "Settings:", (col_x_start + padding, settings_y_start + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
- settings_text_lines = []
- payload = seg_data.get("task_payload", {})
- settings_text_lines.append(f"Task ID: {payload.get('task_id', 'N/A')[:10]}...")
- prompt_short = payload.get('prompt', 'N/A')[:35] + ("..." if len(payload.get('prompt', '')) > 35 else "")
- settings_text_lines.append(f"Prompt: {prompt_short}")
- settings_text_lines.append(f"Seed: {payload.get('seed', 'N/A')}, Frames: {payload.get('frames', 'N/A')}")
- settings_text_lines.append(f"Resolution: {payload.get('resolution', 'N/A')}")
- draw_multiline_text(overall_static_template_canvas, settings_text_lines[:max_settings_lines],
- (col_x_start + padding, settings_y_start + text_line_h_approx + padding),
- font, font_scale_small, text_color, line_thickness, 5)
-
- error_placeholder_frame = create_color_frame(target_thumb_size, (50, 0, 0))
- cv2.putText(error_placeholder_frame, "ERR", (10, target_thumb_size[1]//2), font, 0.8, (255,255,255), 1)
- not_found_placeholder_frame = create_color_frame(target_thumb_size, (0, 50, 0))
- cv2.putText(not_found_placeholder_frame, "N/A", (10, target_thumb_size[1]//2), font, 0.8, (255,255,255), 1)
- static_thumbs_cache = {}
- for seg_idx_cache, seg_data_cache in enumerate(segments_data):
- guide_thumb = get_resized_frame(seg_data_cache["guide_video_path"], target_thumb_size, frame_ratio=0.5)
- output_thumb = get_resized_frame(seg_data_cache["raw_headless_output_path"], target_thumb_size, frame_ratio=0.5)
-
- static_thumbs_cache[seg_idx_cache] = {
- 'guide': guide_thumb if guide_thumb is not None else not_found_placeholder_frame,
- 'output': output_thumb if output_thumb is not None else not_found_placeholder_frame
- }
-
- writer = None
- try:
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- writer = cv2.VideoWriter(str(output_path), fourcc, float(fps), (canvas_w, canvas_h))
- if not writer.isOpened():
- dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Failed to open VideoWriter for {output_path}")
- return
-
- dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Writing sequentially animated collage to {output_path}")
-
- for active_seg_idx in range(num_segments):
- dprint(f"Animating segment {active_seg_idx} in collage...")
- caps_for_active_segment = {'guide': None, 'output': None, 'last_frames': {}}
- video_paths_to_load = {
- 'guide': segments_data[active_seg_idx]["guide_video_path"],
- 'output': segments_data[active_seg_idx]["raw_headless_output_path"]
- }
- for key, path_str in video_paths_to_load.items():
- p = Path(path_str)
- if p.exists() and p.stat().st_size > 0:
- cap_video = cv2.VideoCapture(str(p))
- if cap_video.isOpened():
- caps_for_active_segment[key] = cap_video
-                        ret, frame = cap_video.read()
-                        caps_for_active_segment['last_frames'][key] = cv2.resize(frame, target_thumb_size, interpolation=cv2.INTER_AREA) if ret and frame is not None else error_placeholder_frame
- cap_video.set(cv2.CAP_PROP_POS_FRAMES, 0.0)
- else: caps_for_active_segment['last_frames'][key] = error_placeholder_frame
- else: caps_for_active_segment['last_frames'][key] = not_found_placeholder_frame
-
- for frame_num in range(num_frames_for_collage):
- current_frame_canvas = overall_static_template_canvas.copy()
-
- for display_seg_idx in range(num_segments):
- col_x_start = display_seg_idx * col_w
- current_y_pos = header_h
-
- videos_to_composite = [None, None] # guide, output
-
- if display_seg_idx == active_seg_idx:
-                        if caps_for_active_segment['guide']:
-                            ret, frame = caps_for_active_segment['guide'].read()
-                            if ret and frame is not None:
-                                videos_to_composite[0] = cv2.resize(frame, target_thumb_size)
-                                caps_for_active_segment['last_frames']['guide'] = videos_to_composite[0]
-                            else:
-                                videos_to_composite[0] = caps_for_active_segment['last_frames'].get('guide', error_placeholder_frame)
-                        else:
-                            videos_to_composite[0] = caps_for_active_segment['last_frames'].get('guide', not_found_placeholder_frame)
-                        if caps_for_active_segment['output']:
-                            ret, frame = caps_for_active_segment['output'].read()
-                            if ret and frame is not None:
-                                videos_to_composite[1] = cv2.resize(frame, target_thumb_size)
-                                caps_for_active_segment['last_frames']['output'] = videos_to_composite[1]
-                            else:
-                                videos_to_composite[1] = caps_for_active_segment['last_frames'].get('output', error_placeholder_frame)
-                        else:
-                            videos_to_composite[1] = caps_for_active_segment['last_frames'].get('output', not_found_placeholder_frame)
- else:
- videos_to_composite[0] = static_thumbs_cache[display_seg_idx]['guide']
- videos_to_composite[1] = static_thumbs_cache[display_seg_idx]['output']
-
- current_frame_canvas[current_y_pos : current_y_pos + thumb_h, col_x_start + padding : col_x_start + padding + thumb_w] = videos_to_composite[0]
- current_y_pos += thumb_h + padding
- current_frame_canvas[current_y_pos : current_y_pos + thumb_h, col_x_start + padding : col_x_start + padding + thumb_w] = videos_to_composite[1]
-
- writer.write(current_frame_canvas)
-
- if caps_for_active_segment['guide']: caps_for_active_segment['guide'].release()
- if caps_for_active_segment['output']: caps_for_active_segment['output'].release()
- dprint(f"Finished animating segment {active_seg_idx} in collage.")
-
- dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Finished writing sequentially animated debug collage.")
-
- except Exception as e:
- dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Exception during video writing: {e} - {traceback.format_exc()}")
- finally:
- if writer: writer.release()
- dprint("GENERATE_DEBUG_SUMMARY_VIDEO: Video writer released.")
-
-
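The collage generator above expects one dict per segment with the keys it reads (`segment_index`, `guide_video_path`, `raw_headless_output_path`, `task_payload`); a hedged sketch with placeholder paths, noting it is a no-op unless DEBUG_MODE is enabled:

```python
from sm_functions.common_utils import generate_debug_summary_video

segments_data = [
    {
        "segment_index": 0,
        "guide_video_path": "outputs/seg0_guide.mp4",
        "raw_headless_output_path": "outputs/seg0_raw.mp4",
        "task_payload": {"task_id": "abc1234567", "prompt": "a walk through fog",
                         "seed": 42, "frames": 81, "resolution": "640x480"},
    },
]
generate_debug_summary_video(segments_data, "outputs/debug_collage.mp4",
                             fps=16, num_frames_for_collage=81)
```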
-def generate_different_pose_debug_video_summary(
- video_stage_data: list[dict],
- output_path: Path,
- fps: int,
- target_resolution: tuple[int, int] # width, height
-):
- if not DEBUG_MODE: return # Only run if debug mode is on
- dprint(f"Generating different_pose DEBUG VIDEO summary at {output_path} ({fps} FPS, {target_resolution[0]}x{target_resolution[1]})")
-
- all_output_frames = []
-    pil_font = None
- try:
- pil_font = ImageFont.truetype("arial.ttf", size=24)
- except IOError:
- pil_font = ImageFont.load_default()
- dprint("Arial font not found for debug video summary, using default PIL font.")
-
- text_color = (255, 255, 255)
- bg_color = (0,0,0)
- text_bg_opacity = 128
-
- for stage_info in video_stage_data:
- label = stage_info.get('label', 'Unknown Stage')
- file_type = stage_info.get('type', 'image')
- file_path_str = stage_info.get('path')
- display_duration_frames = stage_info.get('display_frames', fps * 2)
-
- if not file_path_str:
- print(f"[Debug Video] Missing path for stage '{label}', skipping.")
- continue
-
- file_path = Path(file_path_str)
- if not file_path.exists():
- print(f"[Debug Video] File not found for stage '{label}': {file_path}, creating placeholder frames.")
- placeholder_frame_np = create_color_frame(target_resolution, (50, 0, 0))
- placeholder_pil = Image.fromarray(cv2.cvtColor(placeholder_frame_np, cv2.COLOR_BGR2RGB))
- draw = ImageDraw.Draw(placeholder_pil)
- draw.text((20, 20), f"{label}\n(File Not Found)", font=pil_font, fill=text_color)
- placeholder_frame_final_np = cv2.cvtColor(np.array(placeholder_pil), cv2.COLOR_RGB2BGR)
- all_output_frames.extend([placeholder_frame_final_np] * display_duration_frames)
- continue
-
- current_stage_frames_np = []
- try:
- if file_type == 'image':
- pil_img = Image.open(file_path).convert("RGB")
- pil_img_resized = pil_img.resize(target_resolution, Image.Resampling.LANCZOS)
- np_bgr_frame = cv2.cvtColor(np.array(pil_img_resized), cv2.COLOR_RGB2BGR)
- current_stage_frames_np = [np_bgr_frame] * display_duration_frames
-
- elif file_type == 'video':
- cap_video = cv2.VideoCapture(str(file_path))
- if not cap_video.isOpened():
- raise IOError(f"Could not open video: {file_path}")
-
- frames_read = 0
- while frames_read < display_duration_frames:
- ret, frame_np = cap_video.read()
- if not ret:
- if current_stage_frames_np:
- current_stage_frames_np.extend([current_stage_frames_np[-1]] * (display_duration_frames - frames_read))
- else:
- err_frame = create_color_frame(target_resolution, (0,50,0))
- err_pil = Image.fromarray(cv2.cvtColor(err_frame, cv2.COLOR_BGR2RGB))
- ImageDraw.Draw(err_pil).text((20,20), f"{label}\n(Video Read Error)", font=pil_font, fill=text_color)
- current_stage_frames_np.extend([cv2.cvtColor(np.array(err_pil), cv2.COLOR_RGB2BGR)] * (display_duration_frames - frames_read))
- break
-
- if frame_np.shape[1] != target_resolution[0] or frame_np.shape[0] != target_resolution[1]:
- frame_np = cv2.resize(frame_np, target_resolution, interpolation=cv2.INTER_AREA)
- current_stage_frames_np.append(frame_np)
- frames_read += 1
- cap_video.release()
- else:
- print(f"[Debug Video] Unknown file type '{file_type}' for stage '{label}'. Skipping.")
- continue
-
- for i in range(len(current_stage_frames_np)):
- frame_pil = Image.fromarray(cv2.cvtColor(current_stage_frames_np[i], cv2.COLOR_BGR2RGB))
- draw = ImageDraw.Draw(frame_pil, 'RGBA')
-
- text_x, text_y = 20, 20
- bbox = draw.textbbox((text_x, text_y), label, font=pil_font)
- rect_coords = [(bbox[0]-5, bbox[1]-5), (bbox[2]+5, bbox[3]+5)]
- draw.rectangle(rect_coords, fill=(bg_color[0], bg_color[1], bg_color[2], text_bg_opacity))
- draw.text((text_x, text_y), label, font=pil_font, fill=text_color)
-
- current_stage_frames_np[i] = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
-
- all_output_frames.extend(current_stage_frames_np)
-
- except Exception as e_stage:
- print(f"[Debug Video] Error processing stage '{label}' (path: {file_path}): {e_stage}")
- traceback.print_exc()
- err_frame_np = create_color_frame(target_resolution, (0,0,50))
- err_pil = Image.fromarray(cv2.cvtColor(err_frame_np, cv2.COLOR_BGR2RGB))
- ImageDraw.Draw(err_pil).text((20,20), f"{label}\n(Stage Processing Error)", font=pil_font, fill=text_color)
- all_output_frames.extend([cv2.cvtColor(np.array(err_pil), cv2.COLOR_RGB2BGR)] * display_duration_frames)
-
- if not all_output_frames:
- dprint("[Debug Video] No frames were generated for the debug video summary.")
- return
-
- print(f"[Debug Video] Creating final video with {len(all_output_frames)} frames.")
- create_video_from_frames_list(all_output_frames, output_path, fps, target_resolution)
- print(f"Debug video summary for 'different_pose' saved to: {output_path.resolve()}")
\ No newline at end of file
diff --git a/sm_functions/different_pose.py b/sm_functions/different_pose.py
deleted file mode 100644
index a4871f862..000000000
--- a/sm_functions/different_pose.py
+++ /dev/null
@@ -1,313 +0,0 @@
-"""Different-pose task handler."""
-
-import json
-import shutil
-import traceback
-from pathlib import Path
-
-import cv2 # pip install opencv-python
-from PIL import Image # pip install Pillow
-
-# Import from the new common_utils module
-from .common_utils import (
- DEBUG_MODE, dprint, generate_unique_task_id, add_task_to_db, poll_task_status,
- save_frame_from_video, create_pose_interpolated_guide_video,
- generate_different_pose_debug_video_summary # For debug mode
-)
-
-def run_different_pose_task(task_args, common_args, parsed_resolution, main_output_dir, db_file_path, executed_command_str: str | None = None):
- print("--- Running Task: Different Pose (Modified Workflow: User Input + OpenPose T2I) ---")
- dprint(f"Task Args: {task_args}")
- dprint(f"Common Args: {common_args}")
-
- num_target_frames = common_args.output_video_frames
- if num_target_frames < 2:
- print(f"Error: --output_video_frames (set to {num_target_frames}) must be at least 2.")
- return 1
-
- user_input_image_path = Path(task_args.input_image).resolve()
- if not user_input_image_path.exists():
- print(f"Error: Input image not found: {user_input_image_path}")
- return 1
-
- task_run_id = generate_unique_task_id("sm_pose_rife_")
- task_work_dir = main_output_dir / f"different_pose_run_{task_run_id}"
- task_work_dir.mkdir(parents=True, exist_ok=True)
- print(f"Working directory for this 'different_pose' run: {task_work_dir.resolve()}")
-
- if DEBUG_MODE and executed_command_str:
- try:
- command_file_path = task_work_dir / "executed_command.txt"
- with open(command_file_path, "w") as f:
- f.write(executed_command_str)
- dprint(f"Saved executed command to {command_file_path}")
- except Exception as e_cmd_save:
- dprint(f"Warning: Could not save executed command: {e_cmd_save}")
-
- debug_video_stages_data = []
- default_image_display_frames = common_args.fps_helpers * 2
-
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"0. User Input ({user_input_image_path.name})",
- 'type': 'image',
- 'path': str(user_input_image_path),
- 'display_frames': default_image_display_frames
- })
-
- # --- Step 1: Generate OpenPose from User Input Image ---
- print("\nStep 1: Generating OpenPose from user input image...")
- openpose_user_input_task_id = generate_unique_task_id("pose_op_user_")
- openpose_user_input_output_path = task_work_dir / f"{openpose_user_input_task_id}_openpose_from_user.png"
-
- openpose_user_input_payload = {
- "task_id": openpose_user_input_task_id,
- "task_type": "generate_openpose",
- "model": "core_pose_processor",
- "input_image_path": str(user_input_image_path.resolve()),
- "output_path": str(openpose_user_input_output_path.resolve())
- }
- dprint(f"OpenPose task payload for user input: {json.dumps(openpose_user_input_payload, indent=2)}")
- try:
- add_task_to_db(openpose_user_input_payload, db_file_path, "generate_openpose")
- openpose_user_input_path_str = poll_task_status(openpose_user_input_task_id, db_file_path, common_args.poll_interval, common_args.poll_timeout)
-    except Exception as e_openpose_user:
-        dprint(f"Warning: OpenPose task for user input image failed: {e_openpose_user}")
-        openpose_user_input_path_str = None
-
- if not openpose_user_input_path_str:
- print("Failed to generate OpenPose from user input image.")
- return 1
- openpose_user_input_path = Path(openpose_user_input_path_str)
- print(f"Successfully generated OpenPose from user input image: {openpose_user_input_path}")
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"1. OpenPose from User Input ({openpose_user_input_path.name})",
- 'type': 'image',
- 'path': str(openpose_user_input_path),
- 'display_frames': default_image_display_frames
- })
-
- # --- Step 2: Initial Image Generation (Text-to-Image for target pose reference) ---
- print("\nStep 2: Generating target image from prompt (Text-to-Image)...")
- t2i_target_task_id = generate_unique_task_id("pose_t2i_target_")
- t2i_target_guide_video_path = task_work_dir / f"{t2i_target_task_id}_guide.mp4"
- t2i_target_headless_output_path = task_work_dir / f"{t2i_target_task_id}_video_raw.mp4"
- generated_image_for_target_pose_path = task_work_dir / f"{t2i_target_task_id}_generated_from_prompt.png"
-
- temp_neutral_guide_image_path_s2 = task_work_dir / "_temp_neutral_guide_frame_s2.png"
- try:
- neutral_pil_image_s2 = Image.new('RGB', (parsed_resolution[0] // 16, parsed_resolution[1] // 16), (128, 128, 128))
- neutral_pil_image_s2.save(temp_neutral_guide_image_path_s2)
- create_pose_interpolated_guide_video(
- output_video_path=t2i_target_guide_video_path,
- resolution=parsed_resolution,
- total_frames=1, # Single frame for T2I guide
- start_image_path=temp_neutral_guide_image_path_s2, # Use neutral image for both
- end_image_path=temp_neutral_guide_image_path_s2,
- fps=common_args.fps_helpers,
- confidence_threshold=0.01,
- include_face=False,
- include_hands=False
- )
- except Exception as e:
- print(f"Error creating neutral guide video for T2I step: {e}")
- if temp_neutral_guide_image_path_s2.exists(): temp_neutral_guide_image_path_s2.unlink(missing_ok=True)
- return 1
- finally:
- if temp_neutral_guide_image_path_s2.exists(): temp_neutral_guide_image_path_s2.unlink(missing_ok=True)
-
- t2i_target_payload = {
- "task_id": t2i_target_task_id,
- "prompt": task_args.prompt,
- "model": common_args.model_name,
- "resolution": common_args.resolution,
- "frames": 1,
- "seed": common_args.seed,
- "video_guide_path": str(t2i_target_guide_video_path.resolve()),
- "output_path": str(t2i_target_headless_output_path.resolve())
- }
-    if common_args.use_causvid_lora:
-        t2i_target_payload["use_causvid_lora"] = True
-        dprint("Enabled use_causvid_lora for T2I target payload.")
- dprint(f"Target T2I/I2I task payload: {json.dumps(t2i_target_payload, indent=2)}")
- try:
- add_task_to_db(t2i_target_payload, db_file_path, task_args.task)
- raw_video_from_headless_s2 = poll_task_status(t2i_target_task_id, db_file_path, common_args.poll_interval, common_args.poll_timeout)
-    except Exception as e_t2i_target:
-        dprint(f"Warning: Target T2I task failed: {e_t2i_target}")
-        raw_video_from_headless_s2 = None
-
- if not raw_video_from_headless_s2:
- print("Failed to generate target image from prompt.")
- return 1
- t2i_video_output_path_s2 = Path(raw_video_from_headless_s2)
- if not save_frame_from_video(t2i_video_output_path_s2, 0, generated_image_for_target_pose_path, parsed_resolution):
- print(f"Failed to extract frame from target T2I video: {t2i_video_output_path_s2}")
- return 1
- print(f"Successfully generated target image from prompt: {generated_image_for_target_pose_path}")
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"2. Target T2I Image ({generated_image_for_target_pose_path.name})",
- 'type': 'image',
- 'path': str(generated_image_for_target_pose_path),
- 'display_frames': default_image_display_frames
- })
-
- # --- Step 3: Generate OpenPose from the Target T2I Image ---
- print("\nStep 3: Generating OpenPose from the target text-generated image...")
- openpose_t2i_task_id = generate_unique_task_id("pose_op_t2i_")
- openpose_t2i_output_path = task_work_dir / f"{openpose_t2i_task_id}_openpose_from_t2i.png"
-
- openpose_t2i_payload = {
- "task_id": openpose_t2i_task_id,
- "task_type": "generate_openpose",
- "model": "core_pose_processor",
- "input_image_path": str(generated_image_for_target_pose_path.resolve()),
- "output_path": str(openpose_t2i_output_path.resolve())
- }
- dprint(f"OpenPose T2I/I2I task payload: {json.dumps(openpose_t2i_payload, indent=2)}")
- try:
- add_task_to_db(openpose_t2i_payload, db_file_path, "generate_openpose")
- openpose_t2i_path_str = poll_task_status(openpose_t2i_task_id, db_file_path, common_args.poll_interval, common_args.poll_timeout)
-    except Exception as e_openpose_t2i:
-        dprint(f"Warning: OpenPose task for T2I image failed: {e_openpose_t2i}")
-        openpose_t2i_path_str = None
-
- if not openpose_t2i_path_str:
- print("Failed to generate OpenPose from T2I image.")
- return 1
- openpose_t2i_path = Path(openpose_t2i_path_str)
- print(f"Successfully generated OpenPose from T2I image: {openpose_t2i_path}")
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"3. OpenPose from T2I ({openpose_t2i_path.name})",
- 'type': 'image',
- 'path': str(openpose_t2i_path),
- 'display_frames': default_image_display_frames
- })
-
- print("\nStep 4: RIFE Interpolation SKIPPED as per modification.")
-
- # --- Step 5: Create Final Composite Guide Video ---
- print("\nStep 5: Creating custom guide video (OpenPose T2I only)...")
- custom_guide_video_id = generate_unique_task_id("pose_custom_guide_")
- custom_guide_video_path = task_work_dir / f"{custom_guide_video_id}_custom_guide.mp4"
-
- try:
- create_pose_interpolated_guide_video(
- output_video_path=custom_guide_video_path,
- resolution=parsed_resolution,
- total_frames=num_target_frames,
- start_image_path=user_input_image_path,
- end_image_path=generated_image_for_target_pose_path,
- fps=common_args.fps_helpers,
- confidence_threshold=0.1,
- include_face=True,
- include_hands=True
- )
- print(f"Successfully created pose-interpolated guide video: {custom_guide_video_path}")
- except Exception as e_custom_guide:
- print(f"Error creating custom guide video: {e_custom_guide}")
- traceback.print_exc()
- return 1
-
- final_video_guide_path = custom_guide_video_path
-
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"4. Custom Guide Video ({custom_guide_video_path.name})",
- 'type': 'video',
- 'path': str(custom_guide_video_path.resolve()),
- 'display_frames': common_args.fps_helpers * (num_target_frames // common_args.fps_helpers + 1)
- })
-
- # --- Step 6: Final Video Generation ---
- print("\nStep 6: Generating final video using custom guide (OpenPose T2I only) and user input reference...")
- final_video_task_id = generate_unique_task_id("pose_finalvid_custom_")
- final_video_headless_output_path = task_work_dir / f"{final_video_task_id}_video_raw.mp4"
-
- final_video_payload = {
- "task_id": final_video_task_id,
- "prompt": task_args.prompt,
- "model": common_args.model_name,
- "resolution": common_args.resolution,
- "frames": num_target_frames,
- "seed": common_args.seed + 1,
- "video_guide_path": str(final_video_guide_path.resolve()),
- "reference_image_path": str(user_input_image_path.resolve()),
- "image_refs_paths": [str(user_input_image_path.resolve())],
- "image_prompt_type": "IV",
- "output_path": str(final_video_headless_output_path.resolve())
- }
-
-    if common_args.use_causvid_lora:
-        final_video_payload["use_causvid_lora"] = True
-        dprint("Enabled use_causvid_lora for final video payload.")
- dprint(f"Final video generation task payload: {json.dumps(final_video_payload, indent=2)}")
- try:
- add_task_to_db(final_video_payload, db_file_path, task_args.task)
- raw_final_video_from_headless = poll_task_status(final_video_task_id, db_file_path, common_args.poll_interval, common_args.poll_timeout)
-    except Exception as e_final_video:
-        dprint(f"Warning: Final video generation task failed: {e_final_video}")
-        raw_final_video_from_headless = None
-
- if not raw_final_video_from_headless:
- print("Failed to generate final video using custom guide and reference.")
- return 1
- final_posed_video_path = Path(raw_final_video_from_headless)
- print(f"Successfully generated final video with custom guide and reference: {final_posed_video_path}")
-
- print("\nTrimming logic for final video SKIPPED as per modification.")
-
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"5. Final Video with Custom Guide ({final_posed_video_path.name})",
- 'type': 'video',
- 'path': str(final_posed_video_path),
- 'display_frames': common_args.fps_helpers * (num_target_frames // common_args.fps_helpers +1)
- })
-
- # --- Step 7: Result Extraction ---
- print(f"\nStep 7: Extracting the final posed image (last frame of {num_target_frames}-frame video)...")
- final_posed_image_output_path = main_output_dir / f"final_posed_image_{task_run_id}.png"
-
- cap_tmp_s7 = cv2.VideoCapture(str(final_posed_video_path))
- total_frames_generated_s7 = int(cap_tmp_s7.get(cv2.CAP_PROP_FRAME_COUNT)) if cap_tmp_s7.isOpened() else 0
- if cap_tmp_s7.isOpened(): cap_tmp_s7.release()
-
- if total_frames_generated_s7 == 0:
- print(f"Error: Final generated video is empty: {final_posed_video_path}")
- return 1
- target_frame_index_s7 = num_target_frames - 1
- if target_frame_index_s7 >= total_frames_generated_s7 :
- print(f"Warning: Target frame index {target_frame_index_s7} is out of bounds for video with {total_frames_generated_s7} frames. Using last available frame.")
- target_frame_index_s7 = max(0, total_frames_generated_s7 - 1)
-
- if not save_frame_from_video(final_posed_video_path, target_frame_index_s7, final_posed_image_output_path, parsed_resolution):
- print(f"Failed to extract final posed image from custom-guided video with reference: {final_posed_video_path}")
- return 1
-
- print(f"\nSuccessfully completed 'different_pose' task with custom (OpenPose T2I only guide + user input reference) workflow!")
- print(f"Final posed image saved to: {final_posed_image_output_path.resolve()}")
-
- if DEBUG_MODE:
- debug_video_stages_data.append({
- 'label': f"6. Final Extracted Image ({final_posed_image_output_path.name})",
- 'type': 'image',
- 'path': str(final_posed_image_output_path),
- 'display_frames': default_image_display_frames
- })
- if debug_video_stages_data:
- video_collage_file_name_s7 = f"debug_video_summary_pose_rife_{task_run_id}.mp4"
- video_collage_output_path_s7 = main_output_dir / video_collage_file_name_s7
- generate_different_pose_debug_video_summary(
- debug_video_stages_data, video_collage_output_path_s7,
- fps=common_args.fps_helpers, target_resolution=parsed_resolution
- )
-
- if not common_args.skip_cleanup and not DEBUG_MODE:
- print(f"\nCleaning up intermediate files in {task_work_dir}...")
- try:
- shutil.rmtree(task_work_dir)
- print(f"Removed intermediate directory: {task_work_dir}")
- except OSError as e_clean:
- print(f"Error removing intermediate directory {task_work_dir}: {e_clean}")
- else:
- print(f"Skipping cleanup of intermediate files in {task_work_dir}.")
-
- return 0
\ No newline at end of file
diff --git a/sm_functions/travel_between_images.py b/sm_functions/travel_between_images.py
deleted file mode 100644
index 6d9548c4f..000000000
--- a/sm_functions/travel_between_images.py
+++ /dev/null
@@ -1,379 +0,0 @@
-"""Travel-between-images task handler."""
-
-import json
-import shutil
-import traceback
-from pathlib import Path
-import datetime
-import subprocess
-# import math # No longer used directly here
-import requests
-from urllib.parse import urlparse
-import uuid
-
-import cv2
-import numpy as np
-from PIL import Image, ImageEnhance
-
-# Import from the common_utils module
-from .common_utils import (
- DEBUG_MODE, dprint, generate_unique_task_id, add_task_to_db, poll_task_status, # poll_task_status will be removed from usage here
- # extract_video_segment_ffmpeg, # Not used directly by orchestrator
- stitch_videos_ffmpeg, # Will be used by a stitcher task in headless
- # create_pose_interpolated_guide_video, # Guide creation moves to headless
- generate_debug_summary_video, # Potentially used by a final debug task in headless
- # extract_specific_frame_ffmpeg, # Used by headless segment task
- # concatenate_videos_ffmpeg, # Alternative stitch, might be used by headless
- get_video_frame_count_and_fps, # Used by headless segment/stitch task
- _get_unique_target_path,
- image_to_frame,
- create_color_frame,
- _adjust_frame_brightness,
- _copy_to_folder_with_unique_name, # For downloading initial video
- _apply_strength_to_image # For preparing VACE refs if done by orchestrator
-)
-
-# Import from the video_utils module (some might be used by headless now)
-from .video_utils import (
- crossfade_ease,
- _blend_linear,
- _blend_linear_sharp,
- cross_fade_overlap_frames,
- extract_frames_from_video,
- create_video_from_frames_list,
- _apply_saturation_to_video_ffmpeg,
- color_match_video_to_reference
-)
-
-DEFAULT_SEGMENT_FRAMES = 81 # This constant might still be relevant for defaults before expansion
-
-
-# --- Easing functions (for guide video timing, not cross-fade) ---
-# These are kept if orchestrator prepares any initial guide data or passes easing params
-def ease_linear(t: float) -> float:
- """Linear interpolation (0..1 -> 0..1)"""
- return t
-
-def ease_in_quad(t: float) -> float:
- """Ease-in quadratic (0..1 -> 0..1)"""
- return t * t
-
-def ease_out_quad(t: float) -> float:
- """Ease-out quadratic (0..1 -> 0..1)"""
- return t * (2 - t)
-
-def ease_in_out_quad(t: float) -> float:
- """Ease-in-out quadratic (0..1 -> 0..1)"""
-    return 2 * t * t if t < 0.5 else 1 - ((-2 * t + 2) * (-2 * t + 2)) / 2
-
-
-def get_easing_function(curve_type: str):
- if curve_type == "ease_in":
- return ease_in_quad
- elif curve_type == "ease_out":
- return ease_out_quad
- elif curve_type == "ease_in_out":
- return ease_in_out_quad
- elif curve_type == "linear":
- return ease_linear
- else: # Default or unknown
- dprint(f"Warning: Unknown curve_type '{curve_type}'. Defaulting to ease_in_out_quad.")
- return ease_in_out_quad
-
-
-# --- Helper: Re-encode a video to H.264 using FFmpeg (ensures consistent codec) ---
-# This might be used by headless if a continued video needs re-encoding.
-def _reencode_to_h264_ffmpeg(
- input_video_path: str | Path,
- output_video_path: str | Path,
- fps: float | None = None,
- resolution: tuple[int, int] | None = None,
- crf: int = 23,
- preset: str = "veryfast"
-):
- """Re-encodes the entire input video to H.264 using libx264."""
- inp = Path(input_video_path)
- outp = Path(output_video_path)
- outp.parent.mkdir(parents=True, exist_ok=True)
-
- cmd = [
- "ffmpeg", "-y",
- "-i", str(inp.resolve()),
- "-an",
- "-c:v", "libx264",
- "-pix_fmt", "yuv420p",
- "-crf", str(crf),
- "-preset", preset,
- ]
- if fps is not None and fps > 0:
- cmd.extend(["-r", str(fps)])
- if resolution is not None:
- w, h = resolution
- cmd.extend(["-vf", f"scale={w}:{h}"])
- cmd.append(str(outp.resolve()))
-
- dprint(f"REENCODE_TO_H264_FFMPEG: Running command: {' '.join(cmd)}")
- try:
- subprocess.run(cmd, check=True, capture_output=True, text=True, encoding="utf-8")
- return outp.exists() and outp.stat().st_size > 0
- except subprocess.CalledProcessError as e:
- print(f"Error during FFmpeg re-encode of {inp} -> {outp}:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}")
- return False
-
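A minimal sketch of how the re-encode helper might be used so that a continued input video shares a codec with freshly generated segments (paths and settings are illustrative):

```python
from sm_functions.travel_between_images import _reencode_to_h264_ffmpeg

ok = _reencode_to_h264_ffmpeg(
    "outputs/continued_input.mov",
    "outputs/continued_input_h264.mp4",
    fps=25,
    resolution=(768, 512),
    crf=20,
)
if not ok:
    raise RuntimeError("Re-encode to H.264 failed; check the ffmpeg output printed above.")
```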
-def run_travel_between_images_task(task_args, common_args, parsed_resolution, main_output_dir, db_file_path, executed_command_str: str | None = None):
- print("--- Queuing Task: Travel Between Images (Orchestrator) ---")
- dprint(f"Task Args: {task_args}")
- dprint(f"Common Args: {common_args}")
-
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
- run_id = timestamp # Unique ID for this entire travel operation
-
- # The orchestrator task itself might not need a deep processing folder,
- # as headless will manage folders for its child segment/stitch tasks.
- # However, a shallow folder for the orchestrator's own log/command might be useful.
- orchestrator_log_folder_name = f"travel_orchestrator_log_{run_id}"
- orchestrator_log_folder = main_output_dir / orchestrator_log_folder_name
- orchestrator_log_folder.mkdir(parents=True, exist_ok=True)
- dprint(f"Orchestrator logs/info for this run will be in: {orchestrator_log_folder}")
-
- if DEBUG_MODE and executed_command_str:
- try:
- command_file_path = orchestrator_log_folder / "executed_command.txt"
- with open(command_file_path, "w") as f:
- f.write(executed_command_str)
- dprint(f"Saved executed command to {command_file_path}")
- except Exception as e_cmd_save:
- dprint(f"Warning: Could not save executed command: {e_cmd_save}")
-
- # --- Helper function to download video if URL (used for continue_from_video) ---
- # This needs to run here to get the initial video path for the orchestrator payload.
- def _download_video_if_url(video_url_or_path: str, target_dir: Path, base_name: str) -> Path | None:
- parsed_url = urlparse(video_url_or_path)
- if parsed_url.scheme in ['http', 'https']:
- try:
- dprint(f"Downloading video from URL: {video_url_or_path}")
- response = requests.get(video_url_or_path, stream=True, timeout=300)
- response.raise_for_status()
- original_filename = Path(parsed_url.path).name
- original_suffix = Path(original_filename).suffix if Path(original_filename).suffix else ".mp4"
- if not original_suffix.startswith('.'):
- original_suffix = '.' + original_suffix
- # Use common_utils' _get_unique_target_path
- temp_download_path = _get_unique_target_path(target_dir, base_name, original_suffix)
- with open(temp_download_path, 'wb') as f:
- for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
- dprint(f"Video downloaded successfully to {temp_download_path}")
- return temp_download_path
- except Exception as e_req:
- print(f"Error downloading video {video_url_or_path}: {e_req}")
- return None
- else:
- local_path = Path(video_url_or_path)
- if not local_path.exists():
- print(f"Error: Local video file not found: {local_path}")
- return None
- # Use common_utils' _copy_to_folder_with_unique_name
- copied_local_video_path = _copy_to_folder_with_unique_name(
- source_path=local_path, target_dir=target_dir, base_name=base_name,
- extension=local_path.suffix if local_path.suffix else ".mp4"
- )
- if copied_local_video_path:
- dprint(f"Copied local video {local_path} to {copied_local_video_path}")
- return copied_local_video_path
- else:
- print(f"Error: Failed to copy local video {local_path} to orchestrator folder.")
- return None
- # --- End Download Helper ---
-
- # --- Determine number of segments ---
- if task_args.continue_from_video:
- if not task_args.input_images: # Already validated by steerable_motion.py main
- print("Error: --input_images must be provided with --continue_from_video.")
- return 1 # Should not happen due to prior validation
- num_segments_to_generate = len(task_args.input_images)
- else:
- if len(task_args.input_images) < 2: # Already validated
- print("Error: At least two input images are required without --continue_from_video.")
- return 1 # Should not happen
- num_segments_to_generate = len(task_args.input_images) - 1
-
- # --- Validate num_segments_to_generate ---
- if num_segments_to_generate <= 0:
- print(f"Error: Based on input_images ({len(task_args.input_images)}) and continue_from_video flag, no new segments would be generated.")
- return 1
-
- expanded_base_prompts = task_args.base_prompts * num_segments_to_generate if len(task_args.base_prompts) == 1 else task_args.base_prompts
- expanded_negative_prompts = task_args.negative_prompts * num_segments_to_generate if len(task_args.negative_prompts) == 1 else task_args.negative_prompts
- expanded_segment_frames = task_args.segment_frames * num_segments_to_generate if len(task_args.segment_frames) == 1 else task_args.segment_frames
- expanded_frame_overlap = task_args.frame_overlap * num_segments_to_generate if len(task_args.frame_overlap) == 1 else task_args.frame_overlap
-
- # --- Prepare VACE Image Refs (if strength > 0) ---
- # This logic might still live in the orchestrator if it needs to prepare these files upfront
- # and pass their paths to headless. Or, headless could do it. For now, keeping it here.
- # Paths will be absolute, resolved here.
- vace_image_refs_details = [] # List of dicts: {"type": "initial" or "final", "original_path": str, "processed_path": str, "strength": float}
-
- # Helper to prepare a VACE ref image
- def _prepare_vace_ref_image(original_path_str: str, strength: float, ref_type: str, segment_idx_for_naming: int, target_dir: Path) -> dict | None:
- original_path = Path(original_path_str).resolve()
- clamped_strength = max(0.0, min(1.0, strength))
- if not original_path.exists() or clamped_strength == 0.0:
- dprint(f"VACE Ref ({ref_type} for seg {segment_idx_for_naming}): Original path {original_path} not found or strength {clamped_strength} is 0. Skipping.")
- return None
-
- detail = {"type": ref_type, "original_path": str(original_path), "strength": clamped_strength}
-
- processed_name_base = f"s{segment_idx_for_naming}_{ref_type}_anchor_strength_{clamped_strength:.2f}"
- original_suffix = original_path.suffix if original_path.suffix else ".png"
-
- # Use common_utils
- path_for_processed_ref = _get_unique_target_path(target_dir, processed_name_base, original_suffix)
-
- if abs(clamped_strength - 1.0) < 1e-5: # Strength is 1.0
- try:
- shutil.copy2(str(original_path), str(path_for_processed_ref))
- detail["processed_path"] = str(path_for_processed_ref.resolve())
- dprint(f"VACE Ref ({ref_type} for seg {segment_idx_for_naming}): Copied original due to strength 1.0: {path_for_processed_ref}")
- return detail
- except Exception as e_copy:
- dprint(f"VACE Ref ({ref_type} for seg {segment_idx_for_naming}): Failed to copy original: {e_copy}. Skipping.")
- return None
- else: # Strength < 1.0
- processed_path = _apply_strength_to_image(
- image_path=original_path, strength=clamped_strength,
- output_path=path_for_processed_ref, target_resolution=parsed_resolution
- )
- if processed_path and processed_path.exists():
- detail["processed_path"] = str(processed_path.resolve())
- dprint(f"VACE Ref ({ref_type} for seg {segment_idx_for_naming}): Processed with strength {clamped_strength}: {processed_path}")
- return detail
- else:
- dprint(f"VACE Ref ({ref_type} for seg {segment_idx_for_naming}): Failed to process with strength {clamped_strength}. Skipping.")
- return None
-
- # Collect VACE refs per *newly generated* segment
- # The orchestrator_log_folder is a good temporary place for these processed VACE refs
- # if they are generated by the orchestrator itself.
- for i in range(num_segments_to_generate):
- # Initial Anchor for VACE
- initial_anchor_path_str_for_vace: str | None = None
- if i == 0 and not task_args.continue_from_video: # Very first segment from scratch
- initial_anchor_path_str_for_vace = task_args.input_images[0]
- elif i > 0 : # Subsequent new segment, initial anchor is previous new segment's end
- # If continue_from_video, input_images[i-1] is the end of (i-1)th *new* segment
- # If not continue_from_video, input_images[i] is the end of (i-1)th *new* segment (since input_images[0] was start of 0th new)
- initial_anchor_path_str_for_vace = task_args.input_images[i] if not task_args.continue_from_video else task_args.input_images[i-1]
-
-
- if initial_anchor_path_str_for_vace and common_args.initial_image_strength > 0.0:
- ref_detail = _prepare_vace_ref_image(initial_anchor_path_str_for_vace, common_args.initial_image_strength, "initial", i, orchestrator_log_folder)
- if ref_detail: vace_image_refs_details.append(ref_detail)
-
- # Final Anchor for VACE (end of current new segment i)
- # If continue_from_video, end anchor is input_images[i]
- # If not, end anchor is input_images[i+1]
- final_anchor_path_str_for_vace = task_args.input_images[i] if task_args.continue_from_video else task_args.input_images[i+1]
- if common_args.final_image_strength > 0.0:
- ref_detail = _prepare_vace_ref_image(final_anchor_path_str_for_vace, common_args.final_image_strength, "final", i, orchestrator_log_folder)
- if ref_detail: vace_image_refs_details.append(ref_detail)
- # --- End VACE Image Refs ---
-
- orchestrator_task_id = generate_unique_task_id(f"sm_travel_orchestrator_{run_id[:8]}_" )
-
- # Handle continue_from_video: download it now if it's a URL, get its path
- initial_video_path_for_headless: str | None = None
- if task_args.continue_from_video:
- dprint(f"Orchestrator: continue_from_video specified: {task_args.continue_from_video}")
- # Download to orchestrator_log_folder, headless will copy/use it from there or orchestrator will specify absolute path.
- # For simplicity, let orchestrator download it to a known sub-folder that headless can expect or is told about.
- # A sub-folder within orchestrator_log_folder is fine.
- downloaded_continued_video_path = _download_video_if_url(
- task_args.continue_from_video,
- orchestrator_log_folder, # Store it within the orchestrator's own log/asset folder
- "continued_video_input"
- )
- if downloaded_continued_video_path and downloaded_continued_video_path.exists():
- initial_video_path_for_headless = str(downloaded_continued_video_path.resolve())
- dprint(f"Orchestrator: Continue video prepared at: {initial_video_path_for_headless}")
- # Optionally, re-encode here if needed, though headless could also do this.
- # For now, assume headless can handle it or will re-encode if necessary.
- else:
- print(f"Error: Could not load or download video from {task_args.continue_from_video}. Cannot proceed with orchestrator task.")
- return 1
-
- orchestrator_payload = {
- "orchestrator_task_id": orchestrator_task_id,
- "run_id": run_id, # For grouping segment task outputs later in headless if needed
- "original_task_args": vars(task_args), # Store original CLI args for this travel task
- "original_common_args": vars(common_args), # Store original common args
- "parsed_resolution_wh": parsed_resolution, # (width, height) tuple
- "main_output_dir_for_run": str(main_output_dir.resolve()), # Base output dir for headless to use
- "orchestrator_log_folder": str(orchestrator_log_folder.resolve()), # For headless to potentially write logs or find assets
-
- "input_image_paths_resolved": [str(Path(p).resolve()) for p in task_args.input_images],
- "continue_from_video_resolved_path": initial_video_path_for_headless, # Path to downloaded/copied video if used
-
- "num_new_segments_to_generate": num_segments_to_generate,
- "base_prompts_expanded": expanded_base_prompts,
- "negative_prompts_expanded": expanded_negative_prompts,
- "segment_frames_expanded": expanded_segment_frames,
- "frame_overlap_expanded": expanded_frame_overlap,
-
- "vace_image_refs_prepared": vace_image_refs_details, # List of dicts with paths to strength-adjusted images
-
- # Pass fade params directly for headless to use when creating guides
- "fade_in_params_json_str": common_args.fade_in_duration,
- "fade_out_params_json_str": common_args.fade_out_duration,
-
- # Other common args that headless might need for segment tasks or final stitch/upscale
- "model_name": common_args.model_name,
- "seed_base": common_args.seed,
- "execution_engine": common_args.execution_engine,
- "use_causvid_lora": common_args.use_causvid_lora,
- "cfg_star_switch": common_args.cfg_star_switch,
- "cfg_zero_step": common_args.cfg_zero_step,
- "params_json_str_override": common_args.params_json_str, # For headless to merge into segment tasks
- "fps_helpers": common_args.fps_helpers, # For guide/stitch tasks in headless
- "last_frame_duplication": common_args.last_frame_duplication,
- "subsequent_starting_strength_adjustment": common_args.subsequent_starting_strength_adjustment,
- "desaturate_subsequent_starting_frames": common_args.desaturate_subsequent_starting_frames,
- "adjust_brightness_subsequent_starting_frames": common_args.adjust_brightness_subsequent_starting_frames,
- "after_first_post_generation_saturation": common_args.after_first_post_generation_saturation,
- "crossfade_sharp_amt": getattr(task_args, 'crossfade_sharp_amt', 0.3), # from travel_args default or value
-
- "upscale_factor": task_args.upscale_factor, # From travel_args
- "upscale_model_name": common_args.upscale_model_name,
-
- "debug_mode_enabled": DEBUG_MODE, # For headless to know if it should run in debug
- "skip_cleanup_enabled": common_args.skip_cleanup # For headless to respect
- }
-
- try:
- # Add the single orchestrator task to the DB
- add_task_to_db(
- task_payload={ # This is the `params` field in the DB
- "orchestrator_details": orchestrator_payload
- },
- db_path_str=db_file_path,
- task_type="travel_orchestrator", # New task type
- task_id_override=orchestrator_task_id # Use the generated ID
- )
- print(f"Successfully enqueued 'travel_orchestrator' task (ID: {orchestrator_task_id}).")
- print(f"Headless.py will now process this sequence.")
- dprint(f"Orchestrator payload submitted: {json.dumps(orchestrator_payload, indent=2, default=str)}")
- except Exception as e_db_add:
- print(f"Failed to add travel_orchestrator task {orchestrator_task_id} to DB: {e_db_add}")
- traceback.print_exc()
- return 1
-
- return 0 # Success (orchestrator task queued)
-
-# Note: The main loop, polling, segment processing, guide video creation,
-# VACE ref application (if moved to headless), stitching, and cleanup
-# are now expected to be handled by headless.py based on the orchestrator task.
-# Functions like _get_unique_target_path, image_to_frame, create_color_frame,
-# _adjust_frame_brightness etc. are now imported from common_utils.
-# Video processing functions like extract_frames_from_video, create_video_from_frames_list,
-# cross_fade_overlap_frames etc. are imported from video_utils.
-# They will be called by headless.py.
-
diff --git a/sm_functions/video_utils.py b/sm_functions/video_utils.py
deleted file mode 100644
index 0b2ebf3b2..000000000
--- a/sm_functions/video_utils.py
+++ /dev/null
@@ -1,337 +0,0 @@
-import math
-import subprocess
-from pathlib import Path
-import traceback
-
-import cv2 # pip install opencv-python
-import numpy as np
-
-# Import dprint from common_utils (assuming it's discoverable)
-from .common_utils import dprint, get_video_frame_count_and_fps # Added get_video_frame_count_and_fps
-
-# --- Easing function for cross-fading ---
-def ease(alpha_lin: float) -> float:
- """cosine ease-in-out (0..1 -> 0..1)"""
- return (1 - math.cos(alpha_lin * math.pi)) / 2.0
-
-def _blend_linear(a: np.ndarray, b: np.ndarray, t: float) -> np.ndarray:
- return cv2.addWeighted(a, 1.0-t, b, t, 0)
-
-def _blend_linear_sharp(a: np.ndarray, b: np.ndarray, t: float, amt: float) -> np.ndarray:
- base = _blend_linear(a,b,t)
- if amt<=0: return base
- blur = cv2.GaussianBlur(base,(0,0),3)
- return cv2.addWeighted(base, 1.0+amt*t, blur, -amt*t, 0)
-
-def cross_fade_overlap_frames(
- segment1_frames: list[np.ndarray],
- segment2_frames: list[np.ndarray],
- overlap_count: int,
- mode: str = "linear_sharp",
- sharp_amt: float = 0.3
-) -> list[np.ndarray]:
- """
- Cross-fades the overlapping frames between two segments using various modes.
-
- Args:
- segment1_frames: Frames from the first segment (video ending)
- segment2_frames: Frames from the second segment (video starting)
- overlap_count: Number of frames to cross-fade
- mode: Blending mode ("linear", "linear_sharp")
- sharp_amt: Sharpening amount for "linear_sharp" mode (0-1)
-
- Returns:
- List of cross-faded frames for the overlap region
- """
- if overlap_count <= 0:
- return []
-
- n = min(overlap_count, len(segment1_frames), len(segment2_frames))
- if n <= 0:
- return []
-
- out_frames = []
- for i in range(n):
- t_linear = (i + 1) / float(n)
- alpha = ease(t_linear)
-
- frame_a_np = segment1_frames[-n+i].astype(np.float32)
- frame_b_np = segment2_frames[i].astype(np.float32)
-
- blended_float: np.ndarray
- if mode == "linear_sharp":
- blended_float = _blend_linear_sharp(frame_a_np, frame_b_np, alpha, sharp_amt)
- elif mode == "linear":
- blended_float = _blend_linear(frame_a_np, frame_b_np, alpha)
- else:
- dprint(f"Warning: Unknown crossfade mode '{mode}'. Defaulting to linear.")
- blended_float = _blend_linear(frame_a_np, frame_b_np, alpha)
-
- blended_uint8 = np.clip(blended_float, 0, 255).astype(np.uint8)
- out_frames.append(blended_uint8)
-
- return out_frames
-
-def extract_frames_from_video(video_path: str | Path, start_frame: int = 0, num_frames: int = None) -> list[np.ndarray]:
- """
- Extracts frames from a video file as numpy arrays.
-
- Args:
- video_path: Path to the video file
- start_frame: Starting frame index (0-based)
- num_frames: Number of frames to extract (None = all remaining frames)
-
- Returns:
- List of frames as BGR numpy arrays
- """
- frames = []
- cap = cv2.VideoCapture(str(video_path))
-
- if not cap.isOpened():
- dprint(f"Error: Could not open video {video_path}")
- return frames
-
- total_frames_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
- cap.set(cv2.CAP_PROP_POS_FRAMES, float(start_frame))
-
- frames_to_read = num_frames if num_frames is not None else (total_frames_video - start_frame)
- frames_to_read = min(frames_to_read, total_frames_video - start_frame)
-
- for i in range(frames_to_read):
- ret, frame = cap.read()
- if not ret:
- dprint(f"Warning: Could not read frame {start_frame + i} from {video_path}")
- break
- frames.append(frame)
-
- cap.release()
- return frames
-
-def create_video_from_frames_list(
- frames_list: list[np.ndarray],
- output_path: str | Path,
- fps: int,
- resolution: tuple[int, int]
-) -> Path | None:
- """Creates a video from a list of NumPy BGR frames using FFmpeg subprocess.
- Returns the Path object of the successfully written file, or None if failed.
- """
- output_path_obj = Path(output_path)
- output_path_mp4 = output_path_obj.with_suffix('.mp4')
- output_path_mp4.parent.mkdir(parents=True, exist_ok=True)
-
- ffmpeg_cmd = [
- "ffmpeg", "-y",
- "-loglevel", "error",
- "-f", "rawvideo",
- "-vcodec", "rawvideo",
- "-pix_fmt", "bgr24",
- "-s", f"{resolution[0]}x{resolution[1]}",
- "-r", str(fps),
- "-i", "-",
- "-c:v", "libx264",
- "-pix_fmt", "yuv420p",
- "-preset", "veryfast",
- "-crf", "23",
- str(output_path_mp4.resolve())
- ]
-
- dprint(f"Attempting to write video to: {output_path_mp4} using FFmpeg: {' '.join(ffmpeg_cmd)}")
-
- processed_frames = []
- for frame_idx, frame_np in enumerate(frames_list):
- if frame_np is None:
- dprint(f"Warning: Frame {frame_idx} is None. Skipping.")
- continue
- if not isinstance(frame_np, np.ndarray):
- dprint(f"Warning: Frame {frame_idx} is not a numpy array ({type(frame_np)}). Skipping.")
- continue
- if frame_np.dtype != np.uint8:
- frame_np = frame_np.astype(np.uint8)
- if frame_np.shape[0] != resolution[1] or frame_np.shape[1] != resolution[0] or frame_np.shape[2] != 3:
- dprint(f"Warning: Frame {frame_idx} has incorrect shape {frame_np.shape}, expected ({resolution[1]}, {resolution[0]}, 3). Resizing.")
- try:
- frame_np = cv2.resize(frame_np, resolution, interpolation=cv2.INTER_AREA)
- except Exception as e_resize:
- dprint(f"Error resizing frame {frame_idx}: {e_resize}. Skipping.")
- continue
- processed_frames.append(frame_np)
-
- if not processed_frames:
- dprint("Error: No valid frames to write to video.")
- return None
-
- try:
- raw_video_data = b''.join(frame.tobytes() for frame in processed_frames)
- except Exception as e_data:
- dprint(f"Error creating raw video data: {e_data}")
- return None
-
- try:
- proc = subprocess.run(
- ffmpeg_cmd,
- input=raw_video_data,
- capture_output=True,
- timeout=60
- )
-
- if proc.returncode == 0:
- if output_path_mp4.exists() and output_path_mp4.stat().st_size > 0:
- dprint(f"Generated video (FFmpeg/libx264): {output_path_mp4} ({len(processed_frames)} frames)")
- return output_path_mp4
- else:
- dprint(f"FFmpeg process completed (rc=0) but output file {output_path_mp4} is missing or empty.")
- dprint(f"FFmpeg stdout: {proc.stdout.decode(errors='ignore')}")
- dprint(f"FFmpeg stderr: {proc.stderr.decode(errors='ignore')}")
- return None
- else:
- dprint(f"FFmpeg failed for {output_path_mp4}. Return code: {proc.returncode}")
- dprint(f"FFmpeg stdout: {proc.stdout.decode(errors='ignore')}")
- dprint(f"FFmpeg stderr: {proc.stderr.decode(errors='ignore')}")
- if output_path_mp4.exists():
- try: output_path_mp4.unlink()
- except Exception as e_unlink: dprint(f"Could not remove partially written/failed MP4 file {output_path_mp4}: {e_unlink}")
- return None
-
- except subprocess.TimeoutExpired:
- dprint(f"FFmpeg process timed out after 60 seconds for {output_path_mp4}")
- return None
- except FileNotFoundError:
- dprint("Error: ffmpeg command not found. Please ensure FFmpeg is installed and in your PATH.")
- return None
- except Exception as e_proc:
- dprint(f"Error running FFmpeg subprocess: {e_proc}")
- return None
-
-def _apply_saturation_to_video_ffmpeg(
- input_video_path: str | Path,
- output_video_path: str | Path,
- saturation_level: float,
- preset: str = "veryfast"
-) -> bool:
- """Applies a saturation adjustment to the full video using FFmpeg's eq filter.
- Returns: True if FFmpeg succeeds and the output file exists & is non-empty, else False.
- """
- inp = Path(input_video_path)
- outp = Path(output_video_path)
- outp.parent.mkdir(parents=True, exist_ok=True)
-
- cmd = [
- "ffmpeg", "-y",
- "-i", str(inp.resolve()),
- "-vf", f"eq=saturation={saturation_level}",
- "-c:v", "libx264",
- "-preset", preset,
- "-pix_fmt", "yuv420p",
- "-an",
- str(outp.resolve())
- ]
-
- dprint(f"SATURATION_ADJUST: Running command: {' '.join(cmd)}")
- try:
- subprocess.run(cmd, check=True, capture_output=True, text=True, encoding="utf-8")
- if outp.exists() and outp.stat().st_size > 0:
- return True
- else:
- dprint(f"SATURATION_ADJUST: Output file {outp} missing or empty after ffmpeg run.")
- return False
- except subprocess.CalledProcessError as e:
- print(f"Error during FFmpeg saturation adjust of {inp} -> {outp}:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}")
- return False
-
-def color_match_video_to_reference(source_video_path: str | Path, reference_video_path: str | Path,
- output_video_path: str | Path, parsed_resolution: tuple[int, int]) -> bool:
- """
- Color matches source_video to reference_video using histogram matching on the last frame of reference
- and first frame of source. Applies the transformation to all frames of source_video.
- Returns True if successful, False otherwise.
- """
- try:
- ref_cap = cv2.VideoCapture(str(reference_video_path))
- if not ref_cap.isOpened():
- print(f"Error: Could not open reference video {reference_video_path}")
- return False
-
- ref_frame_count_from_cap, _ = get_video_frame_count_and_fps(str(reference_video_path)) # Use helper
- if ref_frame_count_from_cap is None: ref_frame_count_from_cap = int(ref_cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Fallback
-
- ref_cap.set(cv2.CAP_PROP_POS_FRAMES, float(max(0, ref_frame_count_from_cap - 1)))
- ret, ref_frame = ref_cap.read()
- ref_cap.release()
-
- if not ret or ref_frame is None:
- print(f"Error: Could not read reference frame from {reference_video_path}")
- return False
-
- if ref_frame.shape[1] != parsed_resolution[0] or ref_frame.shape[0] != parsed_resolution[1]:
- ref_frame = cv2.resize(ref_frame, parsed_resolution, interpolation=cv2.INTER_AREA)
-
- src_cap = cv2.VideoCapture(str(source_video_path))
- if not src_cap.isOpened():
- print(f"Error: Could not open source video {source_video_path}")
- return False
-
- src_fps_from_cap = src_cap.get(cv2.CAP_PROP_FPS)
- # src_frame_count_from_cap = int(src_cap.get(cv2.CAP_PROP_FRAME_COUNT)) # Not strictly needed here
-
- ret, src_first_frame = src_cap.read()
- if not ret or src_first_frame is None:
- print(f"Error: Could not read first frame from {source_video_path}")
- src_cap.release()
- return False
-
- if src_first_frame.shape[1] != parsed_resolution[0] or src_first_frame.shape[0] != parsed_resolution[1]:
- src_first_frame = cv2.resize(src_first_frame, parsed_resolution, interpolation=cv2.INTER_AREA)
-
- def match_histogram_channel(source_channel, reference_channel):
- src_hist, _ = np.histogram(source_channel.flatten(), 256, [0, 256])
- ref_hist, _ = np.histogram(reference_channel.flatten(), 256, [0, 256])
- src_cdf = src_hist.cumsum()
- ref_cdf = ref_hist.cumsum()
- src_cdf = src_cdf / src_cdf[-1]
- ref_cdf = ref_cdf / ref_cdf[-1]
- lookup_table = np.zeros(256, dtype=np.uint8)
- for i in range(256):
- closest_idx = np.argmin(np.abs(ref_cdf - src_cdf[i]))
- lookup_table[i] = closest_idx
- return lookup_table
-
- lookup_tables = []
- for channel in range(3):
- lut = match_histogram_channel(src_first_frame[:, :, channel], ref_frame[:, :, channel])
- lookup_tables.append(lut)
-
- output_path_obj = Path(output_video_path)
- output_path_obj.parent.mkdir(parents=True, exist_ok=True)
-
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- out = cv2.VideoWriter(str(output_path_obj), fourcc, float(src_fps_from_cap), parsed_resolution)
- if not out.isOpened():
- print(f"Error: Could not create output video writer for {output_path_obj}")
- src_cap.release()
- return False
-
- src_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
- frames_processed = 0
- while True:
- ret, frame = src_cap.read()
- if not ret: break
- if frame.shape[1] != parsed_resolution[0] or frame.shape[0] != parsed_resolution[1]:
- frame = cv2.resize(frame, parsed_resolution, interpolation=cv2.INTER_AREA)
- matched_frame = frame.copy()
- for channel in range(3):
- matched_frame[:, :, channel] = cv2.LUT(frame[:, :, channel], lookup_tables[channel])
- out.write(matched_frame)
- frames_processed += 1
-
- src_cap.release()
- out.release()
-
- dprint(f"Color matching complete: {frames_processed} frames processed from {source_video_path} to {output_video_path}")
- return True
-
- except Exception as e:
- print(f"Error during color matching: {e}")
- traceback.print_exc()
- return False
\ No newline at end of file
diff --git a/source/__init__.py b/source/__init__.py
new file mode 100644
index 000000000..0519ecba6
--- /dev/null
+++ b/source/__init__.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/source/common_utils.py b/source/common_utils.py
new file mode 100644
index 000000000..cd12eba54
--- /dev/null
+++ b/source/common_utils.py
@@ -0,0 +1,2444 @@
+"""Common utility functions and constants for steerable_motion tasks."""
+
+import json
+import math
+import os
+import shutil
+import sqlite3
+import subprocess
+import tempfile
+import time
+import traceback
+import uuid
+from pathlib import Path
+from datetime import datetime
+from typing import Any, Generator
+
+import cv2 # pip install opencv-python
+import mediapipe as mp # pip install mediapipe
+import numpy as np # pip install numpy
+from PIL import Image, ImageDraw, ImageFont, ImageEnhance # pip install Pillow
+import requests # Added for downloads
+from urllib.parse import urlparse # Added for URL parsing
+import urllib.parse
+
+# --- Global Debug Mode ---
+# This will be set by the main script (steerable_motion.py)
+DEBUG_MODE = False
+
+# --- Constants for DB interaction and defaults ---
+STATUS_QUEUED = "Queued"
+STATUS_IN_PROGRESS = "In Progress"
+STATUS_COMPLETE = "Complete"
+STATUS_FAILED = "Failed"
+DEFAULT_DB_TABLE_NAME = "tasks"
+# DEFAULT_MODEL_NAME = "vace_14B" # Defined in steerable_motion.py's argparser
+# DEFAULT_SEGMENT_FRAMES = 81 # Defined in steerable_motion.py's argparser
+# DEFAULT_FPS_HELPERS = 25 # Defined in steerable_motion.py's argparser
+# DEFAULT_SEED = 12345 # Defined in steerable_motion.py's argparser
+
+# --- Debug / Verbose Logging Helper ---
+def dprint(msg: str):
+ """Print a debug message if DEBUG_MODE is enabled."""
+ if DEBUG_MODE:
+ print(f"[DEBUG SM-COMMON {datetime.utcnow().isoformat()}Z] {msg}")
+
+# --- Helper Functions ---
+
+def snap_resolution_to_model_grid(parsed_res: tuple[int, int]) -> tuple[int, int]:
+ """
+ Snaps resolution to model grid requirements (multiples of 16).
+
+ Args:
+ parsed_res: (width, height) tuple
+
+ Returns:
+ (width, height) tuple snapped to nearest valid values
+ """
+ width, height = parsed_res
+ # Ensure resolution is compatible with model requirements (multiples of 16)
+ width = (width // 16) * 16
+ height = (height // 16) * 16
+ return width, height
+
+def ensure_valid_prompt(prompt: str | None) -> str:
+ """
+ Ensures prompt is valid (not None or empty), returns space as default.
+
+ Args:
+ prompt: Input prompt string or None
+
+ Returns:
+ Valid prompt string (space if input was None/empty)
+ """
+ if not prompt or not prompt.strip():
+ return " "
+ return prompt.strip()
+
+def ensure_valid_negative_prompt(negative_prompt: str | None) -> str:
+ """
+ Ensures negative prompt is valid (not None), returns space as default.
+
+ Args:
+ negative_prompt: Input negative prompt string or None
+
+ Returns:
+ Valid negative prompt string (space if input was None/empty)
+ """
+ if not negative_prompt or not negative_prompt.strip():
+ return " "
+ return negative_prompt.strip()
+
+def process_additional_loras_shared(
+ additional_loras: dict[str, float | str],
+ wgp_mod: Any,
+ model_filename_for_task: str,
+ task_id: str,
+ dprint: callable
+) -> dict[str, str]:
+ """
+ Shared function to process additional LoRAs for any task handler.
+
+ Args:
+ additional_loras: Dict mapping LoRA paths to strength values
+ wgp_mod: WGP module instance
+ model_filename_for_task: Model filename for LoRA directory lookup
+ task_id: Task ID for logging
+ dprint: Debug print function
+
+ Returns:
+ Dict mapping LoRA filenames to strength strings
+ """
+ processed_additional_loras: dict[str, str] = {}
+
+ if not additional_loras:
+ return processed_additional_loras
+
+ try:
+ import urllib.parse
+ base_lora_dir_for_model = Path(wgp_mod.get_lora_dir(model_filename_for_task))
+ base_lora_dir_for_model.mkdir(parents=True, exist_ok=True)
+ wan2gp_lora_root = Path("Wan2GP/loras")
+ wan2gp_lora_root.mkdir(parents=True, exist_ok=True)
+
+ for lora_path, lora_strength in additional_loras.items():
+ try:
+ # Derive filename
+ if lora_path.startswith("http://") or lora_path.startswith("https://"):
+ lora_filename = Path(urllib.parse.urlparse(lora_path).path).name
+ else:
+ lora_filename = Path(lora_path).name
+
+ target_path_in_model_dir = base_lora_dir_for_model / lora_filename
+
+ # --------------------------------------------------
+ # 1) Handle remote URL – download if not already cached
+ # --------------------------------------------------
+ if lora_path.startswith("http://") or lora_path.startswith("https://"):
+ dl_target_path = wan2gp_lora_root / lora_filename
+
+ # Some LoRA hubs publish weights under extremely generic
+ # names (e.g. "pytorch_model.bin", "model.safetensors").
+ # If we already have a file with that generic name cached
+ # in Wan2GP/loras **we reuse it** instead of downloading a
+ # potentially different LoRA into the same filename which
+ # would be ambiguous.
+ GENERIC_NAMES = {
+ "pytorch_model.bin",
+ "pytorch.bin",
+ "model.safetensors",
+ "pytorch_model.safetensors",
+ "diffusion_pytorch_model.safetensors",
+ }
+
+ if lora_filename in GENERIC_NAMES and dl_target_path.exists():
+ dprint(f"Task {task_id}: Generic-named LoRA '{lora_filename}' already present – using cached copy {dl_target_path} instead of re-downloading.")
+ else:
+ if not dl_target_path.exists():
+ dprint(f"Task {task_id}: Downloading LoRA from {lora_path} to {dl_target_path}")
+ try:
+ download_file(lora_path, wan2gp_lora_root, lora_filename)
+ except Exception as e_dl:
+ dprint(f"Task {task_id}: ERROR – failed to download LoRA {lora_path}: {e_dl}")
+ continue # Skip this LoRA
+ else:
+ dprint(f"Task {task_id}: LoRA already downloaded at {dl_target_path}")
+
+ # Copy into model-specific dir if needed
+ if not target_path_in_model_dir.exists():
+ try:
+ shutil.copy(str(dl_target_path), str(target_path_in_model_dir))
+ dprint(f"Task {task_id}: Copied LoRA to model dir {target_path_in_model_dir}")
+ except Exception as e_cp:
+ dprint(f"Task {task_id}: WARNING – could not copy LoRA to model dir: {e_cp}")
+ else:
+ # --------------------------------------------------
+ # 2) Handle local / relative path
+ # --------------------------------------------------
+ src_path = Path(lora_path)
+ if not src_path.is_absolute():
+ # Try relative to CWD
+ src_path = Path.cwd() / src_path
+ if not src_path.exists():
+ dprint(f"Task {task_id}: WARNING – local LoRA path not found: {src_path}")
+ continue
+ # Ensure LoRA is available inside model dir
+ if not target_path_in_model_dir.exists() or src_path.resolve() != target_path_in_model_dir.resolve():
+ try:
+ shutil.copy(str(src_path), str(target_path_in_model_dir))
+ dprint(f"Task {task_id}: Copied local LoRA {src_path} to {target_path_in_model_dir}")
+ except Exception as e_cploc:
+ dprint(f"Task {task_id}: WARNING – could not copy local LoRA to model dir: {e_cploc}")
+
+ # Register for WGP payload
+ processed_additional_loras[lora_filename] = str(lora_strength)
+ except Exception as e_lora_iter:
+ dprint(f"Task {task_id}: Error processing LoRA entry '{lora_path}': {e_lora_iter}")
+ except Exception as e_outer:
+ dprint(f"Task {task_id}: Error during additional LoRA processing: {e_outer}")
+
+ return processed_additional_loras
+
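+# Illustrative mapping for process_additional_loras_shared (values are hypothetical):
+# keys may be URLs or local paths and values are strengths; the returned dict maps the
+# bare filename (as cached/copied into the LoRA dirs) to the strength as a string:
+#   {"https://example.com/loras/my_style.safetensors": 0.8} -> {"my_style.safetensors": "0.8"}
+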
+def parse_resolution(res_str: str) -> tuple[int, int]:
+ """Parses 'WIDTHxHEIGHT' string to (width, height) tuple."""
+ try:
+ w, h = map(int, res_str.split('x'))
+ if w <= 0 or h <= 0:
+ raise ValueError("Width and height must be positive.")
+ return w, h
+ except ValueError as e:
+ raise ValueError(f"Resolution string must be in WIDTHxHEIGHT format with positive integers (e.g., '960x544'), got {res_str}. Error: {e}")
+
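+# Quick illustrative sketch of how the two resolution helpers compose (values are only
+# examples): a CLI string is parsed, then snapped to the model's 16-pixel grid.
+#   w, h = parse_resolution("902x508")                # -> (902, 508)
+#   w, h = snap_resolution_to_model_grid((w, h))      # -> (896, 496), multiples of 16
+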
+def generate_unique_task_id(prefix: str = "") -> str:
+ """Generates a UUID4 string.
+
+ The optional *prefix* parameter is now ignored so that the returned value
+ is a bare RFC-4122 UUID which can be stored in a Postgres `uuid` column
+ without casting errors. The argument is kept in the signature to avoid
+ breaking existing call-sites that still pass a prefix.
+ """
+ return str(uuid.uuid4())
+
+def image_to_frame(image_path: str | Path, target_size: tuple[int, int]) -> np.ndarray | None:
+ """Loads an image, resizes it, and converts to BGR NumPy array for OpenCV."""
+ try:
+ img = Image.open(image_path).convert("RGB")
+ img = img.resize(target_size, Image.Resampling.LANCZOS)
+ return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ except FileNotFoundError:
+ print(f"Error: Image file not found at {image_path}")
+ return None
+ except Exception as e:
+ print(f"Error loading or processing image {image_path}: {e}")
+ return None
+
+def create_color_frame(size: tuple[int, int], color_bgr: tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
+ """Creates a single color BGR frame (default black)."""
+ height, width = size[1], size[0] # size is (width, height)
+ frame = np.full((height, width, 3), color_bgr, dtype=np.uint8)
+ return frame
+
+def get_easing_function(name: str):
+ """
+ Returns an easing function by name.
+ """
+ if name == 'linear':
+ return lambda t: t
+ elif name == 'ease_in_quad':
+ return lambda t: t * t
+ elif name == 'ease_out_quad':
+ return lambda t: t * (2 - t)
+ elif name == 'ease_in_out_quad' or name == 'ease_in_out': # Added alias
+ return lambda t: 2 * t * t if t < 0.5 else -1 + (4 - 2 * t) * t
+ elif name == 'ease_in_cubic':
+ return lambda t: t * t * t
+ elif name == 'ease_out_cubic':
+ return lambda t: 1 + ((t - 1) ** 3)
+ elif name == 'ease_in_out_cubic':
+ return lambda t: 4 * t * t * t if t < 0.5 else 1 + ((2 * t - 2) ** 3) / 2
+ elif name == 'ease_in_quart':
+ return lambda t: t * t * t * t
+ elif name == 'ease_out_quart':
+ return lambda t: 1 - ((t - 1) ** 4)
+ elif name == 'ease_in_out_quart':
+ return lambda t: 8 * t * t * t * t if t < 0.5 else 1 - ((-2 * t + 2) ** 4) / 2
+ elif name == 'ease_in_quint':
+ return lambda t: t * t * t * t * t
+ elif name == 'ease_out_quint':
+ return lambda t: 1 + ((t - 1) ** 5)
+ elif name == 'ease_in_out_quint':
+ return lambda t: 16 * t * t * t * t * t if t < 0.5 else 1 + ((-2 * t + 2) ** 5) / 2
+ elif name == 'ease_in_sine':
+ return lambda t: 1 - math.cos(t * math.pi / 2)
+ elif name == 'ease_out_sine':
+ return lambda t: math.sin(t * math.pi / 2)
+ elif name == 'ease_in_out_sine':
+ return lambda t: -(math.cos(math.pi * t) - 1) / 2
+ elif name == 'ease_in_expo':
+ return lambda t: 0 if t == 0 else 2 ** (10 * (t - 1))
+ elif name == 'ease_out_expo':
+ return lambda t: 1 if t == 1 else 1 - 2 ** (-10 * t)
+    elif name == 'ease_in_out_expo':
+        # Wrap in a named function so this branch returns a callable like the others.
+        def _ease_in_out_expo(t):
+            if t == 0: return 0
+            if t == 1: return 1
+            if t < 0.5:
+                return (2 ** (20 * t - 10)) / 2
+            return (2 - 2 ** (-20 * t + 10)) / 2
+        return _ease_in_out_expo
+ else: # Default to ease_in_out
+ return lambda t: 2 * t * t if t < 0.5 else -1 + (4 - 2 * t) * t
+
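+# Small illustrative sketch (not used elsewhere in this module): an easing function maps
+# normalised time t in [0, 1] to an eased alpha in [0, 1], which can be sampled to build
+# a ramp, e.g. for fades or guide-frame weighting.
+#   ease = get_easing_function('ease_in_out_quad')
+#   alphas = [ease(i / 9) for i in range(10)]   # 10 values rising from 0.0 to 1.0
+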
+def create_video_from_frames_list(
+ frames_list: list[np.ndarray],
+ output_path: str | Path,
+ fps: int,
+ resolution: tuple[int, int] # width, height
+):
+ """Creates an MP4 video from a list of NumPy BGR frames."""
+ output_path_obj = Path(output_path)
+ output_path_mp4 = output_path_obj.with_suffix('.mp4')
+ output_path_mp4.parent.mkdir(parents=True, exist_ok=True)
+
+ ffmpeg_cmd = [
+ "ffmpeg", "-y",
+ "-loglevel", "error",
+ "-f", "rawvideo",
+ "-vcodec", "rawvideo",
+ "-pix_fmt", "bgr24",
+ "-s", f"{resolution[0]}x{resolution[1]}",
+ "-r", str(fps),
+ "-i", "-",
+ "-c:v", "libx264",
+ "-pix_fmt", "yuv420p",
+ "-preset", "veryfast",
+ "-crf", "23",
+ "-vf", "format=yuv420p,colorspace=bt709:iall=bt709:fast=1",
+ "-color_primaries", "bt709",
+ "-color_trc", "bt709",
+ "-colorspace", "bt709",
+ str(output_path_mp4.resolve())
+ ]
+
+ processed_frames = []
+ for frame_idx, frame_np in enumerate(frames_list):
+ if frame_np is None or not isinstance(frame_np, np.ndarray):
+ continue
+ if frame_np.dtype != np.uint8:
+ frame_np = frame_np.astype(np.uint8)
+ if frame_np.shape[0] != resolution[1] or frame_np.shape[1] != resolution[0] or frame_np.shape[2] != 3:
+ try:
+ frame_np = cv2.resize(frame_np, resolution, interpolation=cv2.INTER_AREA)
+ except Exception:
+ continue
+ processed_frames.append(frame_np)
+
+ if not processed_frames:
+ print(f"No valid frames to write for {output_path_mp4}")
+ return None
+
+ try:
+ raw_video_data = b''.join(frame.tobytes() for frame in processed_frames)
+ except Exception as e:
+ print(f"Error preparing video data: {e}")
+ return None
+
+ if not raw_video_data:
+ print(f"No video data to write for {output_path_mp4}")
+ return None
+
+ try:
+ process = subprocess.Popen(ffmpeg_cmd, stdin=subprocess.PIPE)
+ process.communicate(input=raw_video_data)
+ if process.returncode != 0:
+ print(f"FFmpeg failed with return code {process.returncode}")
+ return None
+ return output_path_mp4
+ except Exception as e:
+ print(f"Error during FFmpeg encoding: {e}")
+ return None
+
+def add_task_to_db(task_payload: dict, db_path: str | Path, task_type_str: str, dependant_on: str | None = None):
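+    """Inserts a single queued task row into the tasks table.
+
+    The payload dict is serialised to JSON and stored in the `params` column;
+    `task_type_str`, the status, `created_at`, `project_id` and the optional
+    `dependant_on` id are stored in their own columns.
+    """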
+ conn = sqlite3.connect(str(db_path))
+ cursor = conn.cursor()
+ try:
+        # `task_type` lives in its own column now, so it is stripped from the payload copy
+        # below if a caller still includes it. The remaining payload is stored as JSON in
+        # the `params` column, and task_payload["task_id"] becomes the row's primary key.
+
+ current_task_id = task_payload.get("task_id") # This is the task_id from the original payload structure.
+ if not current_task_id:
+            # Callers (e.g. steerable_motion.py) are expected to supply the task_id; fail loudly otherwise.
+ raise ValueError("task_id must be present in task_payload for add_task_to_db")
+
+ headless_params_dict = task_payload.copy() # Work on a copy
+ if "task_type" in headless_params_dict:
+ del headless_params_dict["task_type"] # Remove if it exists, to avoid redundancy
+
+ params_json_for_db = json.dumps(headless_params_dict)
+ current_timestamp = datetime.utcnow().isoformat() + "Z"
+ project_id = task_payload.get("project_id", "default_project_id") # Get project_id or use default
+
+ cursor.execute(
+ f"INSERT INTO {DEFAULT_DB_TABLE_NAME} (id, params, task_type, status, created_at, project_id, dependant_on) VALUES (?, ?, ?, ?, ?, ?, ?)",
+ (current_task_id, params_json_for_db, task_type_str, STATUS_QUEUED, current_timestamp, project_id, dependant_on)
+ )
+ conn.commit()
+ print(f"Task {current_task_id} (Type: {task_type_str}) added to database {db_path}.")
+ except sqlite3.Error as e:
+ # Use current_task_id in error message if available
+ task_id_for_error = task_payload.get("task_id", "UNKNOWN_TASK_ID")
+ print(f"SQLite error when adding task {task_id_for_error} (Type: {task_type_str}): {e}")
+ raise
+ finally:
+ conn.close()
+
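+# Illustrative call (prompt value is hypothetical): the payload must carry a "task_id",
+# may carry a "project_id" (a default is used otherwise), and the task type is passed
+# separately so it lands in its own column.
+#   add_task_to_db(
+#       task_payload={"task_id": generate_unique_task_id(), "prompt": "a quiet forest"},
+#       db_path="tasks.db",
+#       task_type_str="travel_orchestrator",
+#   )
+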
+def poll_task_status(task_id: str, db_path: str | Path, poll_interval_seconds: int = 10, timeout_seconds: int = 1800) -> str | None:
+ """Polls the DB for task completion and returns the output_location."""
+ print(f"Polling for completion of task {task_id} (timeout: {timeout_seconds}s)...")
+ start_time = time.time()
+ last_status_print_time = 0
+
+ while True:
+ current_time = time.time()
+ if current_time - start_time > timeout_seconds:
+ print(f"Error: Timeout polling for task {task_id} after {timeout_seconds} seconds.")
+ return None
+
+ conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
+ cursor = conn.cursor()
+ try:
+ cursor.execute(f"SELECT status, output_location FROM {DEFAULT_DB_TABLE_NAME} WHERE id = ?", (task_id,))
+ row = cursor.fetchone()
+ except sqlite3.Error as e:
+ print(f"SQLite error while polling task {task_id}: {e}. Retrying...")
+ conn.close()
+ time.sleep(min(poll_interval_seconds, 5)) # Shorter sleep on DB error
+ continue
+ finally:
+ conn.close()
+
+ if row:
+ status = row["status"]
+ output_location = row["output_location"]
+
+ if current_time - last_status_print_time > poll_interval_seconds * 2 : # Print status periodically
+ print(f"Task {task_id}: Status = {status} (Output: {output_location if output_location else 'N/A'})")
+ last_status_print_time = current_time
+
+ if status == STATUS_COMPLETE:
+ if output_location:
+ print(f"Task {task_id} completed successfully. Output: {output_location}")
+ return output_location
+ else:
+ print(f"Error: Task {task_id} is COMPLETE but output_location is missing. Assuming failure.")
+ return None
+ elif status == STATUS_FAILED:
+ print(f"Error: Task {task_id} failed.")
+ return None
+ elif status not in [STATUS_QUEUED, STATUS_IN_PROGRESS]:
+ print(f"Warning: Task {task_id} has unknown status '{status}'. Treating as error.")
+ return None
+ else:
+ if current_time - last_status_print_time > poll_interval_seconds * 2 :
+ print(f"Task {task_id}: Not found in DB yet or status pending...")
+ last_status_print_time = current_time
+
+ time.sleep(poll_interval_seconds)
+
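+# Illustrative usage for a previously queued task_id: blocks until the worker marks the
+# task Complete (returning its output_location) or until failure/timeout (returning None).
+#   output_location = poll_task_status(task_id, "tasks.db", poll_interval_seconds=10)
+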
+def extract_video_segment_ffmpeg(
+ input_video_path: str | Path,
+ output_video_path: str | Path,
+ start_frame_index: int, # 0-indexed
+ num_frames_to_keep: int,
+ input_fps: float, # FPS of the input video for accurate -ss calculation
+ resolution: tuple[int,int]
+):
+ """Extracts a video segment using FFmpeg with stream copy if possible."""
+ dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Called with input='{input_video_path}', output='{output_video_path}', start_idx={start_frame_index}, num_frames={num_frames_to_keep}, input_fps={input_fps}")
+ if num_frames_to_keep <= 0:
+ print(f"Warning: num_frames_to_keep is {num_frames_to_keep} for {output_video_path} (FFmpeg). Nothing to extract.")
+ dprint("EXTRACT_VIDEO_SEGMENT_FFMPEG: num_frames_to_keep is 0 or less, returning.")
+ Path(output_video_path).unlink(missing_ok=True)
+ return
+
+ input_video_path = Path(input_video_path)
+ output_video_path = Path(output_video_path)
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ start_time_seconds = start_frame_index / input_fps
+
+ cmd = [
+ 'ffmpeg',
+ '-y',
+ '-ss', str(start_time_seconds),
+ '-i', str(input_video_path.resolve()),
+ '-vframes', str(num_frames_to_keep),
+ '-an',
+ str(output_video_path.resolve())
+ ]
+
+ dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Running command: {' '.join(cmd)}")
+ try:
+ process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
+ dprint(f"EXTRACT_VIDEO_SEGMENT_FFMPEG: Successfully extracted segment to {output_video_path}")
+ if process.stderr:
+ dprint(f"FFmpeg stderr (for {output_video_path}):\n{process.stderr}")
+ if not output_video_path.exists() or output_video_path.stat().st_size == 0:
+ print(f"Error: FFmpeg command for {output_video_path} apparently succeeded but output file is missing or empty.")
+ dprint(f"FFmpeg command for {output_video_path} produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
+
+ except subprocess.CalledProcessError as e:
+ print(f"Error during FFmpeg segment extraction for {output_video_path}:")
+ print("FFmpeg command:", ' '.join(e.cmd))
+ if e.stdout: print("FFmpeg stdout:\n", e.stdout)
+ if e.stderr: print("FFmpeg stderr:\n", e.stderr)
+ dprint(f"FFmpeg extraction failed for {output_video_path}. Error: {e}")
+ Path(output_video_path).unlink(missing_ok=True)
+ except FileNotFoundError:
+ print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
+ dprint("FFmpeg command not found during segment extraction.")
+ raise
+
+def stitch_videos_ffmpeg(video_paths_list: list[str | Path], output_path: str | Path):
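+    """Stitches the given video files end-to-end into output_path using FFmpeg's concat demuxer (stream copy)."""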
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ if not video_paths_list:
+ print("No videos to stitch.")
+ return
+
+ valid_video_paths = []
+ for p in video_paths_list:
+ resolved_p = Path(p).resolve()
+ if resolved_p.exists() and resolved_p.stat().st_size > 0:
+ valid_video_paths.append(resolved_p)
+ else:
+ print(f"Warning: Video segment {resolved_p} is missing or empty. Skipping from stitch list.")
+
+ if not valid_video_paths:
+ print("No valid video segments found to stitch after checks.")
+ return
+
+ with tempfile.TemporaryDirectory(prefix="ffmpeg_concat_") as tmpdir:
+ filelist_path = Path(tmpdir) / "ffmpeg_filelist.txt"
+ with open(filelist_path, 'w', encoding='utf-8') as f:
+ for video_path in valid_video_paths:
+ f.write(f"file '{video_path.as_posix()}'\n")
+
+ cmd = [
+ 'ffmpeg', '-y', '-f', 'concat', '-safe', '0',
+ '-i', str(filelist_path),
+ '-c', 'copy', str(output_path)
+ ]
+
+ print(f"Running ffmpeg to stitch videos: {' '.join(cmd)}")
+ try:
+ process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
+ print(f"Successfully stitched videos into: {output_path}")
+ if process.stderr: print("FFmpeg log (stderr):\n", process.stderr)
+ except subprocess.CalledProcessError as e:
+ print(f"Error during ffmpeg stitching for {output_path}:")
+ print("FFmpeg command:", ' '.join(e.cmd))
+ if e.stdout: print("FFmpeg stdout:\n", e.stdout)
+ if e.stderr: print("FFmpeg stderr:\n", e.stderr)
+ raise
+ except FileNotFoundError:
+ print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
+ raise
+
+def save_frame_from_video(video_path: Path, frame_index: int, output_image_path: Path, resolution: tuple[int, int]):
+ """Extracts a specific frame from a video, resizes, and saves it as an image."""
+ dprint(f"SAVE_FRAME_FROM_VIDEO: Input='{video_path}', Index={frame_index}, Output='{output_image_path}', Res={resolution}")
+ if not video_path.exists() or video_path.stat().st_size == 0:
+ print(f"Error: Video file for frame extraction not found or empty: {video_path}")
+ return False
+
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ print(f"Error: Could not open video file: {video_path}")
+ return False
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ # Allow Python‐style negative indexing (e.g. -1 for last frame)
+ if frame_index < 0:
+ frame_index = total_frames + frame_index # Convert to positive index
+
+ if frame_index < 0 or frame_index >= total_frames:
+ print(f"Error: Frame index {frame_index} is out of bounds for video {video_path} (total frames: {total_frames}).")
+ cap.release()
+ return False
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, float(frame_index))
+ ret, frame = cap.read()
+ cap.release()
+
+ if not ret or frame is None:
+ print(f"Error: Could not read frame {frame_index} from {video_path}.")
+ return False
+
+ try:
+ if frame.shape[1] != resolution[0] or frame.shape[0] != resolution[1]:
+ dprint(f"SAVE_FRAME_FROM_VIDEO: Resizing frame from {frame.shape[:2]} to {resolution[:2][::-1]}")
+ frame = cv2.resize(frame, resolution, interpolation=cv2.INTER_AREA)
+
+ output_image_path.parent.mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(output_image_path), frame)
+ print(f"Successfully saved frame {frame_index} from {video_path} to {output_image_path}")
+ return True
+ except Exception as e:
+ print(f"Error saving frame to {output_image_path}: {e}")
+ traceback.print_exc()
+ return False
+
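+# Illustrative usage (paths are hypothetical): negative indices follow Python semantics,
+# so -1 saves the final frame of a finished segment.
+#   save_frame_from_video(Path("segment_00.mp4"), -1, Path("segment_00_last.png"), (960, 544))
+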
+# --- FFMPEG-based specific frame extraction ---
+def extract_specific_frame_ffmpeg(
+ input_video_path: str | Path,
+ frame_number: int, # 0-indexed
+ output_image_path: str | Path,
+ input_fps: float # Passed by caller, though not strictly needed for ffmpeg frame index selection using 'eq(n,frame_number)'
+):
+ """Extracts a specific frame from a video using FFmpeg and saves it as an image."""
+ dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Input='{input_video_path}', Frame={frame_number}, Output='{output_image_path}'")
+ input_video_p = Path(input_video_path)
+ output_image_p = Path(output_image_path)
+ output_image_p.parent.mkdir(parents=True, exist_ok=True)
+
+ if not input_video_p.exists() or input_video_p.stat().st_size == 0:
+ print(f"Error: Input video for frame extraction not found or empty: {input_video_p}")
+ dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Input video {input_video_p} not found or empty. Returning False.")
+ return False
+
+ cmd = [
+ 'ffmpeg',
+ '-y', # Overwrite output without asking
+ '-i', str(input_video_p.resolve()),
+ '-vf', f"select=eq(n\,{frame_number})", # Escaped comma for ffmpeg filter syntax
+ '-vframes', '1',
+ str(output_image_p.resolve())
+ ]
+
+ dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Running command: {' '.join(cmd)}")
+ try:
+ process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
+ dprint(f"EXTRACT_SPECIFIC_FRAME_FFMPEG: Successfully extracted frame {frame_number} to {output_image_p}")
+ if process.stderr:
+ dprint(f"FFmpeg stderr (for frame extraction to {output_image_p}):\n{process.stderr}")
+ if not output_image_p.exists() or output_image_p.stat().st_size == 0:
+ print(f"Error: FFmpeg command for frame extraction to {output_image_p} apparently succeeded but output file is missing or empty.")
+ dprint(f"FFmpeg command for {output_image_p} (frame extraction) produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
+ return False
+ return True
+ except subprocess.CalledProcessError as e:
+ print(f"Error during FFmpeg frame extraction for {output_image_p}:")
+ print("FFmpeg command:", ' '.join(e.cmd))
+ if e.stdout: print("FFmpeg stdout:\n", e.stdout)
+ if e.stderr: print("FFmpeg stderr:\n", e.stderr)
+ dprint(f"FFmpeg frame extraction failed for {output_image_p}. Error: {e}")
+ if output_image_p.exists(): output_image_p.unlink(missing_ok=True)
+ return False
+ except FileNotFoundError:
+ print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
+ dprint("FFmpeg command not found during frame extraction.")
+ raise
+
+# --- FFMPEG-based video concatenation (alternative to stitch_videos_ffmpeg if caller manages temp dir) ---
+def concatenate_videos_ffmpeg(
+ video_paths: list[str | Path],
+ output_path: str | Path,
+ temp_dir_for_list: str | Path # Directory where the list file will be created
+):
+ """Concatenates multiple video files into one using FFmpeg, using a provided temp directory for the list file."""
+ output_p = Path(output_path)
+ output_p.parent.mkdir(parents=True, exist_ok=True)
+ temp_dir_p = Path(temp_dir_for_list)
+ temp_dir_p.mkdir(parents=True, exist_ok=True)
+
+ if not video_paths:
+ print("No videos to concatenate.")
+ dprint("CONCATENATE_VIDEOS_FFMPEG: No video paths provided. Returning.")
+ if output_p.exists(): output_p.unlink(missing_ok=True)
+ return
+
+ valid_video_paths = []
+ for p_item in video_paths:
+ resolved_p_item = Path(p_item).resolve()
+ if resolved_p_item.exists() and resolved_p_item.stat().st_size > 0:
+ valid_video_paths.append(resolved_p_item)
+ else:
+ print(f"Warning: Video segment {resolved_p_item} for concatenation is missing or empty. Skipping.")
+ dprint(f"CONCATENATE_VIDEOS_FFMPEG: Skipping invalid video segment {resolved_p_item}")
+
+ if not valid_video_paths:
+ print("No valid video segments found to concatenate after checks.")
+ dprint("CONCATENATE_VIDEOS_FFMPEG: No valid video segments. Returning.")
+ if output_p.exists(): output_p.unlink(missing_ok=True)
+ return
+
+ filelist_path = temp_dir_p / "ffmpeg_concat_filelist.txt"
+ with open(filelist_path, 'w', encoding='utf-8') as f:
+ for video_path_item in valid_video_paths:
+ f.write(f"file '{video_path_item.as_posix()}'\n") # Use as_posix() for ffmpeg list file
+
+ cmd = [
+ 'ffmpeg', '-y',
+ '-f', 'concat',
+ '-safe', '0',
+ '-i', str(filelist_path.resolve()),
+ '-c', 'copy',
+ str(output_p.resolve())
+ ]
+
+ dprint(f"CONCATENATE_VIDEOS_FFMPEG: Running command: {' '.join(cmd)} with list file {filelist_path}")
+ try:
+ process = subprocess.run(cmd, check=True, capture_output=True, text=True, encoding='utf-8')
+ print(f"Successfully concatenated videos into: {output_p}")
+ dprint(f"CONCATENATE_VIDEOS_FFMPEG: Success. Output: {output_p}")
+ if process.stderr:
+ dprint(f"FFmpeg stderr (for concatenation to {output_p}):\n{process.stderr}")
+ if not output_p.exists() or output_p.stat().st_size == 0:
+ print(f"Warning: FFmpeg concatenation to {output_p} apparently succeeded but output file is missing or empty.")
+ dprint(f"FFmpeg command for {output_p} (concatenation) produced no output. stdout:\n{process.stdout}\nstderr:\n{process.stderr}")
+ except subprocess.CalledProcessError as e:
+ print(f"Error during FFmpeg concatenation for {output_p}:")
+ print("FFmpeg command:", ' '.join(e.cmd))
+ if e.stdout: print("FFmpeg stdout:\n", e.stdout)
+ if e.stderr: print("FFmpeg stderr:\n", e.stderr)
+ dprint(f"FFmpeg concatenation failed for {output_p}. Error: {e}")
+ if output_p.exists(): output_p.unlink(missing_ok=True)
+ raise
+ except FileNotFoundError:
+ print("Error: ffmpeg command not found. Please ensure ffmpeg is installed and in your PATH.")
+ dprint("CONCATENATE_VIDEOS_FFMPEG: ffmpeg command not found.")
+ raise
+
+# --- OpenCV-based video properties extraction ---
+def get_video_frame_count_and_fps(video_path: str | Path) -> tuple[int | None, float | None]:
+ """Gets frame count and FPS of a video using OpenCV. Returns (None, None) on failure."""
+ video_path_str = str(Path(video_path).resolve())
+ cap = None
+ try:
+ cap = cv2.VideoCapture(video_path_str)
+ if not cap.isOpened():
+ dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Could not open video: {video_path_str}")
+ return None, None
+
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+
+ # Validate frame_count and fps as they can sometimes be 0 or negative for problematic files/streams
+ valid_frame_count = frame_count if frame_count > 0 else None
+ valid_fps = fps if fps > 0 else None
+
+ if valid_frame_count is None:
+ dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} reported non-positive frame count: {frame_count}. Treating as unknown.")
+ if valid_fps is None:
+ dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} reported non-positive FPS: {fps}. Treating as unknown.")
+
+ dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Video {video_path_str} - Frames: {valid_frame_count}, FPS: {valid_fps}")
+ return valid_frame_count, valid_fps
+ except Exception as e:
+ dprint(f"GET_VIDEO_FRAME_COUNT_FPS: Exception processing {video_path_str}: {e}")
+ return None, None
+ finally:
+ if cap:
+ cap.release()
+
+
+body_colors = [
+ [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]
+]
+face_color = [255, 255, 255]
+hand_keypoint_color = [0, 0, 255]
+hand_limb_colors = [
+ [255,0,0],[255,60,0],[255,120,0],[255,180,0], [180,255,0],[120,255,0],[60,255,0],[0,255,0],
+ [0,255,60],[0,255,120],[0,255,180],[0,180,255], [0,120,255],[0,60,255],[0,0,255],[60,0,255],
+ [120,0,255],[180,0,255],[255,0,180],[255,0,120]
+]
+
+# MediaPipe Pose connections (33 landmarks, indices 0-32)
+# Based on mp.solutions.pose.POSE_CONNECTIONS
+body_skeleton = [
+ (0, 1), (1, 2), (2, 3), (3, 7), # Nose to Left Eye to Left Ear
+ (0, 4), (4, 5), (5, 6), (6, 8), # Nose to Right Eye to Right Ear
+ (9, 10), # Mouth
+ (11, 12), # Shoulders
+ (11, 13), (13, 15), (15, 17), (17, 19), (15, 19), (15, 21), # Left Arm and simplified Left Hand (wrist to fingers)
+ (12, 14), (14, 16), (16, 18), (18, 20), (16, 20), (16, 22), # Right Arm and simplified Right Hand (wrist to fingers)
+ (11, 23), (12, 24), # Connect shoulders to Hips
+ (23, 24), # Hip connection
+ (23, 25), (25, 27), (27, 29), (29, 31), (27, 31), # Left Leg and Foot
+ (24, 26), (26, 28), (28, 30), (30, 32), (28, 32) # Right Leg and Foot
+]
+
+face_skeleton = [] # Draw face dots only, no connections
+
+# MediaPipe Hand connections (21 landmarks per hand, indices 0-20)
+hand_skeleton = [
+ (0, 1), (1, 2), (2, 3), (3, 4), # Thumb
+ (0, 5), (5, 6), (6, 7), (7, 8), # Index finger
+ (0, 9), (9, 10), (10, 11), (11, 12), # Middle finger
+ (0, 13), (13, 14), (14, 15), (15, 16), # Ring finger
+ (0, 17), (17, 18), (18, 19), (19, 20) # Pinky finger
+]
+
+def draw_keypoints_and_skeleton(image, keypoints_data, skeleton_connections, colors_config, confidence_threshold=0.1, point_radius=3, line_thickness=2, is_face=False, is_hand=False):
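+    """Draws keypoints, and optionally skeleton connections, onto ``image`` in place.
+    ``keypoints_data`` is a flat [x, y, confidence, ...] list; points and limbs below
+    ``confidence_threshold`` are skipped. ``colors_config`` is a list of BGR colors for the
+    body, a single color for the face, or a {'limbs': ..., 'points': ...} dict for hands."""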
+ if not keypoints_data:
+ return
+
+ tri_tuples = []
+ if isinstance(keypoints_data, list) and len(keypoints_data) > 0 and isinstance(keypoints_data[0], (int, float)) and len(keypoints_data) % 3 == 0:
+ for i in range(0, len(keypoints_data), 3):
+ tri_tuples.append(keypoints_data[i:i+3])
+ else:
+ dprint(f"draw_keypoints_and_skeleton: Unexpected keypoints_data format or length not divisible by 3. Data: {keypoints_data}")
+ return
+
+ if skeleton_connections:
+ for i, (joint_idx_a, joint_idx_b) in enumerate(skeleton_connections):
+ if joint_idx_a >= len(tri_tuples) or joint_idx_b >= len(tri_tuples):
+ continue
+ a_x, a_y, a_confidence = tri_tuples[joint_idx_a]
+ b_x, b_y, b_confidence = tri_tuples[joint_idx_b]
+
+ if a_confidence >= confidence_threshold and b_confidence >= confidence_threshold:
+ limb_color = None
+ if is_hand:
+ limb_color_list = colors_config['limbs']
+ limb_color = limb_color_list[i % len(limb_color_list)]
+ else:
+ limb_color_list = colors_config if isinstance(colors_config, list) else [colors_config]
+ limb_color = limb_color_list[i % len(limb_color_list)]
+ if limb_color is not None:
+ cv2.line(image, (int(a_x), int(a_y)), (int(b_x), int(b_y)), limb_color, line_thickness)
+
+ for i, (x, y, confidence) in enumerate(tri_tuples):
+ if confidence >= confidence_threshold:
+ point_color = None
+ current_radius = point_radius
+ if is_hand:
+ point_color = colors_config['points']
+ elif is_face:
+ point_color = colors_config
+ current_radius = 2
+ else:
+ point_color_list = colors_config
+ point_color = point_color_list[i % len(point_color_list)]
+ if point_color is not None:
+ cv2.circle(image, (int(x), int(y)), current_radius, point_color, -1)
+
+def gen_skeleton_with_face_hands(pose_keypoints_2d, face_keypoints_2d, hand_left_keypoints_2d, hand_right_keypoints_2d,
+ canvas_width, canvas_height, landmarkType, confidence_threshold=0.1):
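+    """Renders one skeleton frame (body, face, hands) on a black canvas and returns it as a BGR array.
+    Keypoints are scaled to the canvas when ``landmarkType == "OpenPose"`` (normalized input);
+    otherwise they are treated as absolute pixel coordinates."""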
+ image = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
+
+ def scale_keypoints(keypoints, target_w, target_h, input_is_normalized):
+ if not keypoints: return []
+ scaled = []
+ if not isinstance(keypoints, list) or (keypoints and not isinstance(keypoints[0], (int, float))):
+ dprint(f"scale_keypoints: Unexpected keypoints format: {type(keypoints)}. Expecting flat list of numbers.")
+ return []
+
+ for i in range(0, len(keypoints), 3):
+ x, y, conf = keypoints[i:i+3]
+ if input_is_normalized: scaled.extend([x * target_w, y * target_h, conf])
+ else: scaled.extend([x, y, conf])
+ return scaled
+
+ input_is_normalized = (landmarkType == "OpenPose") # This might need adjustment based on actual landmarkType usage
+
+ scaled_pose = scale_keypoints(pose_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
+ scaled_face = scale_keypoints(face_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
+ scaled_hand_left = scale_keypoints(hand_left_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
+ scaled_hand_right = scale_keypoints(hand_right_keypoints_2d, canvas_width, canvas_height, input_is_normalized)
+
+ draw_keypoints_and_skeleton(image, scaled_pose, body_skeleton, body_colors, confidence_threshold, point_radius=6, line_thickness=4)
+ if scaled_face:
+ draw_keypoints_and_skeleton(image, scaled_face, face_skeleton, face_color, confidence_threshold, point_radius=2, line_thickness=1, is_face=True)
+ hand_colors_config = {'limbs': hand_limb_colors, 'points': hand_keypoint_color}
+ if scaled_hand_left:
+ draw_keypoints_and_skeleton(image, scaled_hand_left, hand_skeleton, hand_colors_config, confidence_threshold, point_radius=3, line_thickness=2, is_hand=True)
+ if scaled_hand_right:
+ draw_keypoints_and_skeleton(image, scaled_hand_right, hand_skeleton, hand_colors_config, confidence_threshold, point_radius=3, line_thickness=2, is_hand=True)
+ return image
+
+def transform_all_keypoints(keypoints_1_dict, keypoints_2_dict, frames, interpolation="linear"):
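+    """Interpolates between two keypoint dicts (pose/face/hands) over ``frames`` frames.
+    Returns a list of per-frame dicts with the same keys. Supports linear and eased
+    interpolation; keypoints confident at only one endpoint fade in or out via their confidence."""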
+ def interpolate_keypoint_set(kp1_list, kp2_list, num_frames, interp_method):
+ if not kp1_list and not kp2_list: return [[] for _ in range(num_frames)]
+
+ len1 = len(kp1_list) if kp1_list else 0
+ len2 = len(kp2_list) if kp2_list else 0
+
+ if not kp1_list: kp1_list = [0.0] * len2
+ if not kp2_list: kp2_list = [0.0] * len1
+
+ if len(kp1_list) != len(kp2_list) or not kp1_list or len(kp1_list) % 3 != 0:
+ dprint(f"interpolate_keypoint_set: Mismatched, empty, or non-triplet keypoint lists after padding. KP1 len: {len(kp1_list)}, KP2 len: {len(kp2_list)}. Returning empty sequences.")
+ return [[] for _ in range(num_frames)]
+
+ tri_tuples_1 = [kp1_list[i:i + 3] for i in range(0, len(kp1_list), 3)]
+ tri_tuples_2 = [kp2_list[i:i + 3] for i in range(0, len(kp2_list), 3)]
+
+ keypoints_sequence = []
+ for j in range(num_frames):
+ interpolated_kps_for_frame = []
+ t = j / float(num_frames - 1) if num_frames > 1 else 0.0
+
+ interp_factor = t
+ if interp_method == "ease-in": interp_factor = t * t
+ elif interp_method == "ease-out": interp_factor = 1 - (1 - t) * (1 - t)
+ elif interp_method == "ease-in-out":
+ if t < 0.5: interp_factor = 2 * t * t
+ else: interp_factor = 1 - pow(-2 * t + 2, 2) / 2
+
+ for i in range(len(tri_tuples_1)):
+ x1, y1, c1 = tri_tuples_1[i]
+ x2, y2, c2 = tri_tuples_2[i]
+ new_x, new_y, new_c = 0.0, 0.0, 0.0
+
+ if c1 > 0.05 and c2 > 0.05:
+ new_x = x1 + (x2 - x1) * interp_factor
+ new_y = y1 + (y2 - y1) * interp_factor
+ new_c = c1 + (c2 - c1) * interp_factor
+ elif c1 > 0.05 and c2 <= 0.05:
+ new_x, new_y = x1, y1
+ new_c = c1 * (1.0 - interp_factor)
+ elif c1 <= 0.05 and c2 > 0.05:
+ new_x, new_y = x2, y2
+ new_c = c2 * interp_factor
+ interpolated_kps_for_frame.extend([new_x, new_y, new_c])
+ keypoints_sequence.append(interpolated_kps_for_frame)
+ return keypoints_sequence
+
+ pose_1 = keypoints_1_dict.get('pose_keypoints_2d', [])
+ face_1 = keypoints_1_dict.get('face_keypoints_2d', [])
+ hand_left_1 = keypoints_1_dict.get('hand_left_keypoints_2d', [])
+ hand_right_1 = keypoints_1_dict.get('hand_right_keypoints_2d', [])
+
+ pose_2 = keypoints_2_dict.get('pose_keypoints_2d', [])
+ face_2 = keypoints_2_dict.get('face_keypoints_2d', [])
+ hand_left_2 = keypoints_2_dict.get('hand_left_keypoints_2d', [])
+ hand_right_2 = keypoints_2_dict.get('hand_right_keypoints_2d', [])
+
+ pose_sequence = interpolate_keypoint_set(pose_1, pose_2, frames, interpolation)
+ face_sequence = interpolate_keypoint_set(face_1, face_2, frames, interpolation)
+ hand_left_sequence = interpolate_keypoint_set(hand_left_1, hand_left_2, frames, interpolation)
+ hand_right_sequence = interpolate_keypoint_set(hand_right_1, hand_right_2, frames, interpolation)
+
+ combined_sequence = []
+ for i in range(frames):
+ combined_frame_data = {
+ 'pose_keypoints_2d': pose_sequence[i] if i < len(pose_sequence) else [],
+ 'face_keypoints_2d': face_sequence[i] if i < len(face_sequence) else [],
+ 'hand_left_keypoints_2d': hand_left_sequence[i] if i < len(hand_left_sequence) else [],
+ 'hand_right_keypoints_2d': hand_right_sequence[i] if i < len(hand_right_sequence) else []
+ }
+ combined_sequence.append(combined_frame_data)
+ return combined_sequence
+
+def extract_pose_keypoints(image_path: str | Path, include_face=True, include_hands=True, resolution: tuple[int,int]=(640,480)) -> dict:
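+    """Runs MediaPipe Holistic on a single image and returns OpenPose-style keypoint lists.
+    Normalized landmark coordinates are scaled to the given ``resolution`` (width, height),
+    so the returned values are absolute pixel coordinates."""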
+ # import mediapipe as mp # Already imported at top
+ # import cv2 # Already imported at top
+ image = cv2.imread(str(image_path))
+ if image is None:
+ raise ValueError(f"Failed to load image: {image_path}")
+
+ # Resize image before processing if its resolution differs significantly,
+ # or ensure MediaPipe processes at a consistent internal resolution.
+ # For now, assume MediaPipe handles various input sizes well, and keypoints are normalized.
+ # The passed 'resolution' param will be used to scale normalized keypoints back to absolute.
+
+ height, width = resolution[1], resolution[0] # For scaling output coords
+
+ mp_holistic = mp.solutions.holistic
+ # It's good practice to use a try-finally for resources like MediaPipe Holistic
+ holistic_instance = mp_holistic.Holistic(static_image_mode=True,
+ min_detection_confidence=0.5,
+ min_tracking_confidence=0.5)
+ try:
+ # Convert BGR image to RGB for MediaPipe
+ results = holistic_instance.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+ finally:
+ holistic_instance.close()
+
+ keypoints = {}
+ pose_kps = []
+ if results.pose_landmarks:
+ for lm in results.pose_landmarks.landmark:
+ # lm.x, lm.y are normalized coordinates; scale them to target resolution
+ pose_kps.extend([lm.x * width, lm.y * height, lm.visibility])
+ keypoints['pose_keypoints_2d'] = pose_kps
+
+ face_kps = []
+ if include_face and results.face_landmarks:
+ for lm in results.face_landmarks.landmark:
+ face_kps.extend([lm.x * width, lm.y * height, lm.visibility if hasattr(lm, 'visibility') else 1.0]) # Some face landmarks might not have visibility
+ keypoints['face_keypoints_2d'] = face_kps
+
+ left_hand_kps = []
+ if include_hands and results.left_hand_landmarks:
+ for lm in results.left_hand_landmarks.landmark:
+ left_hand_kps.extend([lm.x * width, lm.y * height, lm.visibility])
+ keypoints['hand_left_keypoints_2d'] = left_hand_kps
+
+ right_hand_kps = []
+ if include_hands and results.right_hand_landmarks:
+ for lm in results.right_hand_landmarks.landmark:
+ right_hand_kps.extend([lm.x * width, lm.y * height, lm.visibility])
+ keypoints['hand_right_keypoints_2d'] = right_hand_kps
+
+ return keypoints
+
+def create_pose_interpolated_guide_video(output_video_path: str | Path, resolution: tuple[int, int], total_frames: int,
+ start_image_path: str | Path, end_image_path: str | Path,
+ interpolation="linear", confidence_threshold=0.1,
+ include_face=True, include_hands=True, fps=25):
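+    """Builds a guide video whose first frame is the actual start image; the remaining frames
+    are rendered skeletons interpolated between the poses of the start and end images."""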
+ dprint(f"Creating pose interpolated guide: {output_video_path} from '{Path(start_image_path).name}' to '{Path(end_image_path).name}' ({total_frames} frames). First frame will be actual start image.")
+
+ if total_frames <= 0:
+ dprint(f"Video creation skipped for {output_video_path} as total_frames is {total_frames}.")
+ return
+
+ frames_list = []
+ canvas_width, canvas_height = resolution
+
+ first_visual_frame_np = image_to_frame(start_image_path, resolution)
+ if first_visual_frame_np is None:
+ print(f"Error loading start image {start_image_path} for guide video frame 0. Using black frame.")
+ first_visual_frame_np = create_color_frame(resolution, (0,0,0))
+ frames_list.append(first_visual_frame_np)
+
+ if total_frames > 1:
+ try:
+ # Pass the target resolution for keypoint scaling
+ keypoints_from = extract_pose_keypoints(start_image_path, include_face, include_hands, resolution)
+ keypoints_to = extract_pose_keypoints(end_image_path, include_face, include_hands, resolution)
+ except Exception as e_extract:
+ print(f"Error extracting keypoints for pose interpolation: {e_extract}. Filling remaining guide frames with black.")
+ traceback.print_exc()
+ black_frame = create_color_frame(resolution, (0,0,0))
+ for _ in range(total_frames - 1):
+ frames_list.append(black_frame)
+ create_video_from_frames_list(frames_list, output_video_path, fps, resolution)
+ return
+
+ interpolated_sequence = transform_all_keypoints(keypoints_from, keypoints_to, total_frames, interpolation)
+
+ # landmarkType for gen_skeleton_with_face_hands should indicate absolute coordinates
+ # as extract_pose_keypoints now returns absolute coordinates scaled to 'resolution'
+ landmark_type_for_gen = "AbsoluteCoords"
+
+ for i in range(1, total_frames):
+ if i < len(interpolated_sequence):
+ frame_data = interpolated_sequence[i]
+ pose_kps = frame_data.get('pose_keypoints_2d', [])
+ face_kps = frame_data.get('face_keypoints_2d', []) if include_face else []
+ hand_left_kps = frame_data.get('hand_left_keypoints_2d', []) if include_hands else []
+ hand_right_kps = frame_data.get('hand_right_keypoints_2d', []) if include_hands else []
+
+ img = gen_skeleton_with_face_hands(
+ pose_kps, face_kps, hand_left_kps, hand_right_kps,
+ canvas_width, canvas_height,
+ landmark_type_for_gen, # Keypoints are already absolute
+ confidence_threshold
+ )
+ frames_list.append(img)
+ else:
+ dprint(f"Warning: Interpolated sequence too short at index {i} for {output_video_path}. Appending black frame.")
+ frames_list.append(create_color_frame(resolution, (0,0,0)))
+
+ if len(frames_list) != total_frames:
+ dprint(f"Warning: Generated {len(frames_list)} frames for {output_video_path}, expected {total_frames}. Adjusting.")
+ if len(frames_list) < total_frames:
+ last_frame = frames_list[-1] if frames_list else create_color_frame(resolution, (0,0,0))
+ frames_list.extend([last_frame.copy() for _ in range(total_frames - len(frames_list))])
+ else:
+ frames_list = frames_list[:total_frames]
+
+ if not frames_list:
+ dprint(f"Error: No frames for video {output_video_path}. Skipping creation.")
+ return
+
+ create_video_from_frames_list(frames_list, output_video_path, fps, resolution)
+
+# --- Debug Summary Video Helpers ---
+def get_resized_frame(video_path_str: str, target_size: tuple[int, int], frame_ratio: float = 0.5) -> np.ndarray | None:
+ """Extracts a frame (by ratio, e.g., 0.5 for middle) from a video and resizes it."""
+ video_path = Path(video_path_str)
+ if not video_path.exists() or video_path.stat().st_size == 0:
+ dprint(f"GET_RESIZED_FRAME: Video not found or empty: {video_path_str}")
+ placeholder = create_color_frame(target_size, (10, 10, 10)) # Dark grey
+ cv2.putText(placeholder, "Not Found", (10, target_size[1] // 2), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)
+ return placeholder
+
+ cap = None
+ try:
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ dprint(f"GET_RESIZED_FRAME: Could not open video: {video_path_str}")
+ return create_color_frame(target_size, (20,20,20))
+
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if total_frames == 0:
+ dprint(f"GET_RESIZED_FRAME: Video has 0 frames: {video_path_str}")
+ return create_color_frame(target_size, (30,30,30))
+
+ frame_to_get = int(total_frames * frame_ratio)
+ frame_to_get = max(0, min(frame_to_get, total_frames - 1)) # Clamp
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, float(frame_to_get))
+ ret, frame = cap.read()
+ if not ret or frame is None:
+ dprint(f"GET_RESIZED_FRAME: Could not read frame {frame_to_get} from: {video_path_str}")
+ return create_color_frame(target_size, (40,40,40))
+
+ return cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
+ except Exception as e:
+ dprint(f"GET_RESIZED_FRAME: Exception processing {video_path_str}: {e}")
+ return create_color_frame(target_size, (50,50,50)) # Error color
+ finally:
+ if cap: cap.release()
+
+def draw_multiline_text(image, text_lines, start_pos, font, font_scale, color, thickness, line_spacing):
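+    """Draws each line of ``text_lines`` onto ``image`` with simple vertical spacing and returns the image."""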
+ x, y = start_pos
+ for i, line in enumerate(text_lines):
+ line_y = y + (i * (cv2.getTextSize(line, font, font_scale, thickness)[0][1] + line_spacing))
+ cv2.putText(image, line, (x, line_y), font, font_scale, color, thickness, cv2.LINE_AA)
+ return image
+
+def generate_debug_summary_video(segments_data: list[dict], output_path: str | Path, fps: int,
+ num_frames_for_collage: int,
+ target_thumb_size: tuple[int, int] = (320, 180)):
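+    """Creates an animated debug collage: one column per segment showing its input guide video,
+    the headless output, and key task settings. No-op unless DEBUG_MODE is enabled."""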
+ if not DEBUG_MODE: return # Only run if debug mode is on
+ if not segments_data:
+ dprint("GENERATE_DEBUG_SUMMARY_VIDEO: No segment data provided.")
+ return
+
+ dprint(f"Generating animated debug collage with {num_frames_for_collage} frames, at {fps} FPS.")
+
+ thumb_w, thumb_h = target_thumb_size
+ padding = 10
+ header_h = 50
+ text_line_h_approx = 20
+ max_settings_lines = 6
+ settings_area_h = (text_line_h_approx * max_settings_lines) + padding
+
+ num_segments = len(segments_data)
+ col_w = thumb_w + (2 * padding)
+ canvas_w = num_segments * col_w
+ canvas_h = header_h + (thumb_h * 2) + (padding * 3) + settings_area_h + padding
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale_small = 0.4
+ font_scale_title = 0.6
+ text_color = (230, 230, 230)
+ title_color = (255, 255, 255)
+ line_thickness = 1
+
+ overall_static_template_canvas = np.full((canvas_h, canvas_w, 3), (30, 30, 30), dtype=np.uint8)
+ for idx, seg_data in enumerate(segments_data):
+ col_x_start = idx * col_w
+ center_x_col = col_x_start + col_w // 2
+ title_text = f"Segment {seg_data['segment_index']}"
+ (tw, th), _ = cv2.getTextSize(title_text, font, font_scale_title, line_thickness)
+ cv2.putText(overall_static_template_canvas, title_text, (center_x_col - tw//2, header_h - padding), font, font_scale_title, title_color, line_thickness, cv2.LINE_AA)
+
+ y_offset = header_h
+ cv2.putText(overall_static_template_canvas, "Input Guide", (col_x_start + padding, y_offset + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
+ y_offset += thumb_h + padding
+ cv2.putText(overall_static_template_canvas, "Headless Output", (col_x_start + padding, y_offset + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
+ y_offset += thumb_h + padding
+
+ settings_y_start = y_offset
+ cv2.putText(overall_static_template_canvas, "Settings:", (col_x_start + padding, settings_y_start + text_line_h_approx), font, font_scale_small, text_color, line_thickness)
+ settings_text_lines = []
+ payload = seg_data.get("task_payload", {})
+ settings_text_lines.append(f"Task ID: {payload.get('task_id', 'N/A')[:10]}...")
+ prompt_short = payload.get('prompt', 'N/A')[:35] + ("..." if len(payload.get('prompt', '')) > 35 else "")
+ settings_text_lines.append(f"Prompt: {prompt_short}")
+ settings_text_lines.append(f"Seed: {payload.get('seed', 'N/A')}, Frames: {payload.get('frames', 'N/A')}")
+ settings_text_lines.append(f"Resolution: {payload.get('resolution', 'N/A')}")
+ draw_multiline_text(overall_static_template_canvas, settings_text_lines[:max_settings_lines],
+ (col_x_start + padding, settings_y_start + text_line_h_approx + padding),
+ font, font_scale_small, text_color, line_thickness, 5)
+
+ error_placeholder_frame = create_color_frame(target_thumb_size, (50, 0, 0))
+ cv2.putText(error_placeholder_frame, "ERR", (10, target_thumb_size[1]//2), font, 0.8, (255,255,255), 1)
+ not_found_placeholder_frame = create_color_frame(target_thumb_size, (0, 50, 0))
+ cv2.putText(not_found_placeholder_frame, "N/A", (10, target_thumb_size[1]//2), font, 0.8, (255,255,255), 1)
+ static_thumbs_cache = {}
+ for seg_idx_cache, seg_data_cache in enumerate(segments_data):
+ guide_thumb = get_resized_frame(seg_data_cache["guide_video_path"], target_thumb_size, frame_ratio=0.5)
+ output_thumb = get_resized_frame(seg_data_cache["raw_headless_output_path"], target_thumb_size, frame_ratio=0.5)
+
+ static_thumbs_cache[seg_idx_cache] = {
+ 'guide': guide_thumb if guide_thumb is not None else not_found_placeholder_frame,
+ 'output': output_thumb if output_thumb is not None else not_found_placeholder_frame
+ }
+
+ writer = None
+ try:
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ writer = cv2.VideoWriter(str(output_path), fourcc, float(fps), (canvas_w, canvas_h))
+ if not writer.isOpened():
+ dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Failed to open VideoWriter for {output_path}")
+ return
+
+ dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Writing sequentially animated collage to {output_path}")
+
+ for active_seg_idx in range(num_segments):
+ dprint(f"Animating segment {active_seg_idx} in collage...")
+ caps_for_active_segment = {'guide': None, 'output': None, 'last_frames': {}}
+ video_paths_to_load = {
+ 'guide': segments_data[active_seg_idx]["guide_video_path"],
+ 'output': segments_data[active_seg_idx]["raw_headless_output_path"]
+ }
+ for key, path_str in video_paths_to_load.items():
+ p = Path(path_str)
+ if p.exists() and p.stat().st_size > 0:
+ cap_video = cv2.VideoCapture(str(p))
+ if cap_video.isOpened():
+ caps_for_active_segment[key] = cap_video
+                        ret, frame = cap_video.read()
+                        caps_for_active_segment['last_frames'][key] = cv2.resize(frame, target_thumb_size, interpolation=cv2.INTER_AREA) if ret and frame is not None else error_placeholder_frame
+ cap_video.set(cv2.CAP_PROP_POS_FRAMES, 0.0)
+ else: caps_for_active_segment['last_frames'][key] = error_placeholder_frame
+ else: caps_for_active_segment['last_frames'][key] = not_found_placeholder_frame
+
+ for frame_num in range(num_frames_for_collage):
+ current_frame_canvas = overall_static_template_canvas.copy()
+
+ for display_seg_idx in range(num_segments):
+ col_x_start = display_seg_idx * col_w
+ current_y_pos = header_h
+
+ videos_to_composite = [None, None] # guide, output
+
+                    if display_seg_idx == active_seg_idx:
+                        # Read the next frame from the active segment's guide/output captures,
+                        # falling back to the last good frame (or a placeholder) when a read fails.
+                        for slot, key in ((0, 'guide'), (1, 'output')):
+                            cap = caps_for_active_segment[key]
+                            if cap:
+                                ret, frame = cap.read()
+                                if ret and frame is not None:
+                                    videos_to_composite[slot] = cv2.resize(frame, target_thumb_size)
+                                    caps_for_active_segment['last_frames'][key] = videos_to_composite[slot]
+                                else:
+                                    videos_to_composite[slot] = caps_for_active_segment['last_frames'].get(key, error_placeholder_frame)
+                            else:
+                                videos_to_composite[slot] = caps_for_active_segment['last_frames'].get(key, not_found_placeholder_frame)
+ else:
+ videos_to_composite[0] = static_thumbs_cache[display_seg_idx]['guide']
+ videos_to_composite[1] = static_thumbs_cache[display_seg_idx]['output']
+
+ current_frame_canvas[current_y_pos : current_y_pos + thumb_h, col_x_start + padding : col_x_start + padding + thumb_w] = videos_to_composite[0]
+ current_y_pos += thumb_h + padding
+ current_frame_canvas[current_y_pos : current_y_pos + thumb_h, col_x_start + padding : col_x_start + padding + thumb_w] = videos_to_composite[1]
+
+ writer.write(current_frame_canvas)
+
+ if caps_for_active_segment['guide']: caps_for_active_segment['guide'].release()
+ if caps_for_active_segment['output']: caps_for_active_segment['output'].release()
+ dprint(f"Finished animating segment {active_seg_idx} in collage.")
+
+ dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Finished writing sequentially animated debug collage.")
+
+ except Exception as e:
+ dprint(f"GENERATE_DEBUG_SUMMARY_VIDEO: Exception during video writing: {e} - {traceback.format_exc()}")
+ finally:
+ if writer: writer.release()
+ dprint("GENERATE_DEBUG_SUMMARY_VIDEO: Video writer released.")
+
+
+def generate_different_perspective_debug_video_summary(
+ video_stage_data: list[dict],
+ output_path: Path,
+ fps: int,
+ target_resolution: tuple[int, int] # width, height
+):
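+    """Stitches the labelled stages (images/videos) of a 'different_perspective' run into a single
+    annotated summary video. No-op unless DEBUG_MODE is enabled."""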
+ if not DEBUG_MODE: return # Only run if debug mode is on
+ dprint(f"Generating different_perspective DEBUG VIDEO summary at {output_path} ({fps} FPS, {target_resolution[0]}x{target_resolution[1]})")
+
+ all_output_frames = []
+    try:
+ pil_font = ImageFont.truetype("arial.ttf", size=24)
+ except IOError:
+ pil_font = ImageFont.load_default()
+ dprint("Arial font not found for debug video summary, using default PIL font.")
+
+ text_color = (255, 255, 255)
+ bg_color = (0,0,0)
+ text_bg_opacity = 128
+
+ for stage_info in video_stage_data:
+ label = stage_info.get('label', 'Unknown Stage')
+ file_type = stage_info.get('type', 'image')
+ file_path_str = stage_info.get('path')
+ display_duration_frames = stage_info.get('display_frames', fps * 2)
+
+ if not file_path_str:
+ print(f"[Debug Video] Missing path for stage '{label}', skipping.")
+ continue
+
+ file_path = Path(file_path_str)
+ if not file_path.exists():
+ print(f"[Debug Video] File not found for stage '{label}': {file_path}, creating placeholder frames.")
+ placeholder_frame_np = create_color_frame(target_resolution, (50, 0, 0))
+ placeholder_pil = Image.fromarray(cv2.cvtColor(placeholder_frame_np, cv2.COLOR_BGR2RGB))
+ draw = ImageDraw.Draw(placeholder_pil)
+ draw.text((20, 20), f"{label}\n(File Not Found)", font=pil_font, fill=text_color)
+ placeholder_frame_final_np = cv2.cvtColor(np.array(placeholder_pil), cv2.COLOR_RGB2BGR)
+ all_output_frames.extend([placeholder_frame_final_np] * display_duration_frames)
+ continue
+
+ current_stage_frames_np = []
+ try:
+ if file_type == 'image':
+ pil_img = Image.open(file_path).convert("RGB")
+ pil_img_resized = pil_img.resize(target_resolution, Image.Resampling.LANCZOS)
+ np_bgr_frame = cv2.cvtColor(np.array(pil_img_resized), cv2.COLOR_RGB2BGR)
+ current_stage_frames_np = [np_bgr_frame] * display_duration_frames
+
+ elif file_type == 'video':
+ cap_video = cv2.VideoCapture(str(file_path))
+ if not cap_video.isOpened():
+ raise IOError(f"Could not open video: {file_path}")
+
+ frames_read = 0
+ while frames_read < display_duration_frames:
+ ret, frame_np = cap_video.read()
+ if not ret:
+ if current_stage_frames_np:
+ current_stage_frames_np.extend([current_stage_frames_np[-1]] * (display_duration_frames - frames_read))
+ else:
+ err_frame = create_color_frame(target_resolution, (0,50,0))
+ err_pil = Image.fromarray(cv2.cvtColor(err_frame, cv2.COLOR_BGR2RGB))
+ ImageDraw.Draw(err_pil).text((20,20), f"{label}\n(Video Read Error)", font=pil_font, fill=text_color)
+ current_stage_frames_np.extend([cv2.cvtColor(np.array(err_pil), cv2.COLOR_RGB2BGR)] * (display_duration_frames - frames_read))
+ break
+
+ if frame_np.shape[1] != target_resolution[0] or frame_np.shape[0] != target_resolution[1]:
+ frame_np = cv2.resize(frame_np, target_resolution, interpolation=cv2.INTER_AREA)
+ current_stage_frames_np.append(frame_np)
+ frames_read += 1
+ cap_video.release()
+ else:
+ print(f"[Debug Video] Unknown file type '{file_type}' for stage '{label}'. Skipping.")
+ continue
+
+ for i in range(len(current_stage_frames_np)):
+ frame_pil = Image.fromarray(cv2.cvtColor(current_stage_frames_np[i], cv2.COLOR_BGR2RGB))
+ draw = ImageDraw.Draw(frame_pil, 'RGBA')
+
+ text_x, text_y = 20, 20
+ bbox = draw.textbbox((text_x, text_y), label, font=pil_font)
+ rect_coords = [(bbox[0]-5, bbox[1]-5), (bbox[2]+5, bbox[3]+5)]
+ draw.rectangle(rect_coords, fill=(bg_color[0], bg_color[1], bg_color[2], text_bg_opacity))
+ draw.text((text_x, text_y), label, font=pil_font, fill=text_color)
+
+ current_stage_frames_np[i] = cv2.cvtColor(np.array(frame_pil), cv2.COLOR_RGB2BGR)
+
+ all_output_frames.extend(current_stage_frames_np)
+
+ except Exception as e_stage:
+ print(f"[Debug Video] Error processing stage '{label}' (path: {file_path}): {e_stage}")
+ traceback.print_exc()
+ err_frame_np = create_color_frame(target_resolution, (0,0,50))
+ err_pil = Image.fromarray(cv2.cvtColor(err_frame_np, cv2.COLOR_BGR2RGB))
+ ImageDraw.Draw(err_pil).text((20,20), f"{label}\n(Stage Processing Error)", font=pil_font, fill=text_color)
+ all_output_frames.extend([cv2.cvtColor(np.array(err_pil), cv2.COLOR_RGB2BGR)] * display_duration_frames)
+
+ if not all_output_frames:
+ dprint("[Debug Video] No frames were generated for the debug video summary.")
+ return
+
+ print(f"[Debug Video] Creating final video with {len(all_output_frames)} frames.")
+ create_video_from_frames_list(all_output_frames, output_path, fps, target_resolution)
+ print(f"Debug video summary for 'different_perspective' saved to: {output_path.resolve()}")
+
+def download_file(url, dest_folder, filename):
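+    """Downloads ``url`` into ``dest_folder/filename``, preferring huggingface_hub for HuggingFace
+    URLs and validating LoRA/safetensors files where possible. Returns True on success."""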
+ dest_path = Path(dest_folder) / filename
+ if dest_path.exists():
+ # Validate existing file before assuming it's good
+ if filename.endswith('.safetensors') or 'lora' in filename.lower():
+ is_valid, validation_msg = validate_lora_file(dest_path, filename)
+ if is_valid:
+ print(f"[INFO] File {filename} already exists and is valid in {dest_folder}. {validation_msg}")
+ return True
+ else:
+ print(f"[WARNING] Existing file {filename} failed validation ({validation_msg}). Re-downloading...")
+ dest_path.unlink()
+ else:
+ print(f"[INFO] File {filename} already exists in {dest_folder}.")
+ return True
+
+ # Use huggingface_hub for HuggingFace URLs for better reliability
+ if "huggingface.co" in url:
+ try:
+ from huggingface_hub import hf_hub_download
+ from urllib.parse import urlparse
+
+ # Parse HuggingFace URL to extract repo_id and filename
+ # Format: https://huggingface.co/USER/REPO/resolve/BRANCH/FILENAME
+ parsed = urlparse(url)
+ path_parts = parsed.path.strip('/').split('/')
+
+ if len(path_parts) >= 4 and path_parts[2] == 'resolve':
+ repo_id = f"{path_parts[0]}/{path_parts[1]}"
+ branch = path_parts[3] if len(path_parts) > 4 else "main"
+ hf_filename = '/'.join(path_parts[4:]) if len(path_parts) > 4 else filename
+
+ print(f"Downloading {filename} from HuggingFace repo {repo_id} using hf_hub_download...")
+
+ # Download using huggingface_hub with automatic checksums and resumption
+ downloaded_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=hf_filename,
+ revision=branch,
+ cache_dir=str(dest_folder),
+ resume_download=True,
+ local_files_only=False
+ )
+
+ # Copy from HF cache to target location if different
+ if Path(downloaded_path) != dest_path:
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
+ import shutil
+ shutil.copy2(downloaded_path, dest_path)
+
+ # Validate the downloaded file
+ if filename.endswith('.safetensors') or 'lora' in filename.lower():
+ is_valid, validation_msg = validate_lora_file(dest_path, filename)
+ if not is_valid:
+ print(f"[ERROR] Downloaded file {filename} failed validation: {validation_msg}")
+ dest_path.unlink(missing_ok=True)
+ return False
+ print(f"Successfully downloaded and validated {filename}. {validation_msg}")
+ else:
+ print(f"Successfully downloaded {filename} with integrity verification.")
+ return True
+
+ except ImportError:
+ print(f"[WARNING] huggingface_hub not available, falling back to requests for {url}")
+ except Exception as e:
+ print(f"[WARNING] HuggingFace download failed for {filename}: {e}, falling back to requests")
+
+ # Fallback to requests with basic integrity checks
+ try:
+ print(f"Downloading {filename} from {url} to {dest_folder}...")
+ response = requests.get(url, stream=True)
+ response.raise_for_status() # Raise an exception for HTTP errors
+
+ # Get expected content length for verification
+ expected_size = int(response.headers.get('content-length', 0))
+
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(dest_path, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ downloaded_size += len(chunk)
+
+ # Verify download integrity
+ actual_size = dest_path.stat().st_size
+ if expected_size > 0 and actual_size != expected_size:
+ print(f"[ERROR] Size mismatch for {filename}: expected {expected_size}, got {actual_size}")
+ dest_path.unlink(missing_ok=True)
+ return False
+
+ # Use comprehensive validation for LoRA files
+ if filename.endswith('.safetensors') or 'lora' in filename.lower():
+ is_valid, validation_msg = validate_lora_file(dest_path, filename)
+ if not is_valid:
+ print(f"[ERROR] Downloaded file {filename} failed validation: {validation_msg}")
+ dest_path.unlink(missing_ok=True)
+ return False
+ print(f"Successfully downloaded and validated {filename}. {validation_msg}")
+ else:
+ # For non-LoRA safetensors files, do basic format check
+ if filename.endswith('.safetensors'):
+ try:
+ import safetensors.torch as st
+ with st.safe_open(dest_path, framework="pt") as f:
+ pass # Just verify it can be opened
+ print(f"Successfully downloaded and verified safetensors file {filename}.")
+ except ImportError:
+ print(f"[WARNING] safetensors not available for verification of {filename}")
+ except Exception as e:
+ print(f"[ERROR] Downloaded safetensors file {filename} appears corrupted: {e}")
+ dest_path.unlink(missing_ok=True)
+ return False
+ else:
+ print(f"Successfully downloaded {filename}.")
+
+ return True
+
+ except Exception as e:
+ print(f"[ERROR] Failed to download {filename}: {e}")
+        if dest_path.exists():  # Attempt to clean up partial download
+            try:
+                dest_path.unlink()
+            except OSError:
+                pass
+ return False
+
+# Added to provide a unique target path generator for files.
+def _get_unique_target_path(target_dir: Path, base_name: str, extension: str) -> Path:
+ """Generates a unique target Path in the given directory by appending a timestamp and random string."""
+ target_dir = Path(target_dir)
+ target_dir.mkdir(parents=True, exist_ok=True)
+ timestamp_short = datetime.now().strftime("%H%M%S")
+ # Use a short UUID/random string to significantly reduce collision probability with just timestamp
+ unique_suffix = uuid.uuid4().hex[:6]
+
+ # Construct the filename
+ # Ensure extension has a leading dot
+ if extension and not extension.startswith('.'):
+ extension = '.' + extension
+
+ filename = f"{base_name}_{timestamp_short}_{unique_suffix}{extension}"
+ return target_dir / filename
+
+def download_image_if_url(image_url_or_path: str, download_target_dir: Path | str | None, task_id_for_logging: str | None = "generic_task") -> str:
+ """
+ Checks if the given string is an HTTP/HTTPS URL. If so, and if download_target_dir is provided,
+ downloads the image to a unique path within download_target_dir.
+ Returns the local file path string if downloaded, otherwise returns the original string.
+ """
+ if not image_url_or_path:
+ return image_url_or_path
+
+ parsed_url = urlparse(image_url_or_path)
+ if parsed_url.scheme in ['http', 'https'] and download_target_dir:
+ target_dir_path = Path(download_target_dir)
+ try:
+ target_dir_path.mkdir(parents=True, exist_ok=True)
+ dprint(f"Task {task_id_for_logging}: Downloading image from URL: {image_url_or_path} to {target_dir_path.resolve()}")
+
+ # Use a session for potential keep-alive and connection pooling
+ with requests.Session() as s:
+ response = s.get(image_url_or_path, stream=True, timeout=300) # 5 min timeout
+ response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
+
+ original_filename = Path(parsed_url.path).name
+ original_suffix = Path(original_filename).suffix if Path(original_filename).suffix else ".jpg" # Default to .jpg if no suffix
+ if not original_suffix.startswith('.'):
+ original_suffix = '.' + original_suffix
+
+ base_name_for_download = f"downloaded_{Path(original_filename).stem[:50]}" # Limit stem length
+
+ # _get_unique_target_path expects a Path object for target_dir
+ local_image_path = _get_unique_target_path(target_dir_path, base_name_for_download, original_suffix)
+
+ with open(local_image_path, 'wb') as f:
+                # Write the streamed response directly; as a defensive fallback, re-request
+                # without streaming if the underlying stream was somehow already closed.
+ if response.raw.closed: # If stream was closed by raise_for_status or other means
+ with requests.Session() as s_final:
+ final_response = s_final.get(image_url_or_path, stream=False, timeout=300) # Not streaming for direct content write
+ final_response.raise_for_status()
+ f.write(final_response.content)
+ else: # Stream is still open
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ dprint(f"Task {task_id_for_logging}: Image downloaded successfully to {local_image_path.resolve()}")
+ return str(local_image_path.resolve())
+
+ except requests.exceptions.RequestException as e_req:
+ dprint(f"Task {task_id_for_logging}: Error downloading image {image_url_or_path}: {e_req}. Returning original path.")
+ return image_url_or_path
+ except IOError as e_io:
+ dprint(f"Task {task_id_for_logging}: IO error saving image from {image_url_or_path}: {e_io}. Returning original path.")
+ return image_url_or_path
+ except Exception as e_gen:
+ dprint(f"Task {task_id_for_logging}: General error processing image URL {image_url_or_path}: {e_gen}. Returning original path.")
+ return image_url_or_path
+ else:
+ # Not an HTTP/HTTPS URL or no download directory specified
+ dprint(f"Task {task_id_for_logging}: Not downloading image (not URL or no target dir): {image_url_or_path}")
+ return image_url_or_path
+
+def image_to_frame(image_path_str: str | Path, target_resolution_wh: tuple[int, int] | None = None, task_id_for_logging: str | None = "generic_task", image_download_dir: Path | str | None = None) -> np.ndarray | None:
+ """
+ Load an image, optionally resize, and convert to BGR NumPy array.
+ If image_path_str is a URL and image_download_dir is provided, it attempts to download it first.
+ """
+ resolved_image_path_str = image_path_str # Default to original path
+
+ if isinstance(image_path_str, str): # Only attempt download if it's a string (potentially a URL)
+ resolved_image_path_str = download_image_if_url(image_path_str, image_download_dir, task_id_for_logging)
+
+ image_path = Path(resolved_image_path_str)
+
+ if not image_path.exists():
+ dprint(f"Task {task_id_for_logging}: Image file not found at {image_path} (original input: {image_path_str}).")
+ return None
+ try:
+ img = Image.open(image_path).convert("RGB") # Ensure RGB for consistent processing
+ if target_resolution_wh:
+ img = img.resize(target_resolution_wh, Image.Resampling.LANCZOS)
+ return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ except Exception as e:
+ dprint(f"Error loading image {image_path} (original input {image_path_str}): {e}")
+ return None
+
+def _apply_strength_to_image(
+ image_path_input: Path | str, # Changed name to avoid confusion
+ strength: float,
+ output_path: Path,
+ target_resolution_wh: tuple[int, int] | None,
+ task_id_for_logging: str | None = "generic_task",
+ image_download_dir: Path | str | None = None
+) -> Path | None:
+ """
+ Applies a brightness adjustment (strength) to an image, optionally resizes, and saves it.
+ If image_path_input is a URL string and image_download_dir is provided, it attempts to download it first.
+ """
+ resolved_image_path_str = str(image_path_input)
+
+ if isinstance(image_path_input, str):
+ resolved_image_path_str = download_image_if_url(image_path_input, image_download_dir, task_id_for_logging)
+ # Check if image_path_input was a Path object representing a URL (less common for this function)
+ elif isinstance(image_path_input, Path) and image_path_input.as_posix().startswith(('http://', 'https://')):
+ resolved_image_path_str = download_image_if_url(image_path_input.as_posix(), image_download_dir, task_id_for_logging)
+
+ actual_image_path = Path(resolved_image_path_str)
+
+ if not actual_image_path.exists():
+ dprint(f"Task {task_id_for_logging}: Source image not found at {actual_image_path} (original input: {image_path_input}) for strength application.")
+ return None
+ try:
+ # Open the potentially downloaded or original local image
+ img = Image.open(actual_image_path).convert("RGB") # Ensure RGB
+
+ if target_resolution_wh:
+ img = img.resize(target_resolution_wh, Image.Resampling.LANCZOS)
+
+ # Apply the strength factor using PIL.ImageEnhance for brightness
+ enhancer = ImageEnhance.Brightness(img)
+ processed_img = enhancer.enhance(strength) # 'strength' is the factor for brightness
+
+ # Save the adjusted image
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ processed_img.save(output_path) # Save PIL image directly
+
+ dprint(f"Task {task_id_for_logging}: Applied strength {strength} to {actual_image_path.name}, saved to {output_path.name} with resolution {target_resolution_wh if target_resolution_wh else 'original'}")
+ return output_path
+ except Exception as e:
+ dprint(f"Task {task_id_for_logging}: Error in _apply_strength_to_image for {actual_image_path}: {e}")
+ traceback.print_exc()
+ return None
+
+def _copy_to_folder_with_unique_name(source_path: Path, target_dir: Path, base_name: str, extension: str) -> Path | None:
+ """Copies a file to a target directory with a unique name based on timestamp and random string."""
+ if not source_path:
+ dprint(f"COPY: Source path is None for {base_name}{extension}. Skipping copy.")
+ return None
+
+ source_path_obj = Path(source_path)
+ if not source_path_obj.exists():
+ dprint(f"COPY: Source file {source_path_obj} does not exist. Skipping copy.")
+ return None
+
+ # Sanitize extension for _get_unique_target_path
+ actual_extension = source_path_obj.suffix if source_path_obj.suffix else extension
+ if not actual_extension.startswith('.'):
+ actual_extension = '.' + actual_extension
+
+ # Determine unique target path using the new helper
+ target_file = _get_unique_target_path(target_dir, base_name, actual_extension)
+
+ try:
+ # target_dir.mkdir(parents=True, exist_ok=True) # _get_unique_target_path handles this
+ shutil.copy2(str(source_path_obj), str(target_file))
+ dprint(f"COPY: Copied {source_path_obj.name} to {target_file}")
+ return target_file # Return the path of the copied file
+ except Exception as e_copy:
+ dprint(f"COPY: Failed to copy {source_path_obj} to {target_file}: {e_copy}")
+ return None
+
+
+def get_image_dimensions_pil(image_path: str | Path) -> tuple[int, int]:
+ """Returns the dimensions of an image file as (width, height)."""
+ with Image.open(image_path) as img:
+ return img.size
+
+# Added to adjust the brightness of an image/frame.
+def _adjust_frame_brightness(frame: np.ndarray, factor: float) -> np.ndarray:
+ """Adjusts the brightness of a given frame.
+ The 'factor' is interpreted as a delta from the CLI argument:
+ - Positive factor (e.g., 0.1) makes it darker (target_alpha = 1.0 - 0.1 = 0.9).
+ - Negative factor (e.g., -0.1) makes it brighter (target_alpha = 1.0 - (-0.1) = 1.1).
+ - Zero factor means no change (target_alpha = 1.0).
+ """
+ # Convert the CLI-style factor to an alpha for cv2.convertScaleAbs
+ # CLI factor: positive = darker, negative = brighter
+ # cv2 alpha: >1 = brighter, <1 = darker
+ cv2_alpha = 1.0 - factor
+ return cv2.convertScaleAbs(frame, alpha=cv2_alpha, beta=0)
+
+def sm_get_unique_target_path(target_dir: Path, name_stem: str, suffix: str) -> Path:
+ """Generates a unique target Path in the given directory by appending a number if needed."""
+ if not suffix.startswith('.'):
+ suffix = f".{suffix}"
+
+ final_path = target_dir / f"{name_stem}{suffix}"
+ counter = 1
+ while final_path.exists():
+ final_path = target_dir / f"{name_stem}_{counter}{suffix}"
+ counter += 1
+ return final_path
+
+def load_pil_images(
+ paths_list_or_str: list[str] | str,
+ wgp_convert_func: callable,
+ image_download_dir: Path | str | None,
+ task_id_for_log: str,
+ dprint: callable
+) -> list[Any] | None:
+ """
+ Loads one or more images from paths or URLs, downloads them if necessary,
+ and applies a conversion function.
+ """
+ if not paths_list_or_str:
+ return None
+
+ paths_list = paths_list_or_str if isinstance(paths_list_or_str, list) else [paths_list_or_str]
+ images = []
+
+ for p_str in paths_list:
+ local_p_str = download_image_if_url(p_str, image_download_dir, task_id_for_log)
+ if not local_p_str:
+ dprint(f"[Task {task_id_for_log}] Skipping image as download_image_if_url returned nothing for: {p_str}")
+ continue
+
+ p = Path(local_p_str.strip())
+ if not p.is_file():
+ dprint(f"[Task {task_id_for_log}] load_pil_images: Image file not found after potential download: {p} (original: {p_str})")
+ continue
+ try:
+ img = Image.open(p)
+ images.append(wgp_convert_func(img))
+ except Exception as e:
+ print(f"[WARNING] Failed to load image {p}: {e}")
+
+ return images if images else None
+
+def _normalize_activated_loras_list(current_activated) -> list:
+ """Helper to ensure activated_loras is a proper list."""
+ if not isinstance(current_activated, list):
+ try:
+ return [str(item).strip() for item in str(current_activated).split(',') if item.strip()]
+ except Exception:
+ return []
+ return current_activated
+
+def _apply_special_lora_settings(task_id: str, lora_type: str, lora_basename: str, default_steps: int,
+ guidance_scale: float, flow_shift: float, ui_defaults: dict,
+ task_params_dict: dict, tea_cache_setting: float = None):
+ """
+ Shared helper to apply special LoRA settings (CausVid, LightI2X, etc.) to ui_defaults.
+ """
+ print(f"[Task ID: {task_id}] Applying {lora_type} LoRA settings.")
+
+ # [STEPS DEBUG] Add detailed debug for steps logic
+ print(f"[STEPS DEBUG] {lora_type}: task_params_dict keys: {list(task_params_dict.keys())}")
+ if "steps" in task_params_dict:
+ print(f"[STEPS DEBUG] {lora_type}: Found 'steps' = {task_params_dict['steps']}")
+ if "num_inference_steps" in task_params_dict:
+ print(f"[STEPS DEBUG] {lora_type}: Found 'num_inference_steps' = {task_params_dict['num_inference_steps']}")
+ if "video_length" in task_params_dict:
+ print(f"[STEPS DEBUG] {lora_type}: Found 'video_length' = {task_params_dict['video_length']}")
+
+ # Handle steps logic
+ if "steps" in task_params_dict:
+ ui_defaults["num_inference_steps"] = task_params_dict["steps"]
+ print(f"[Task ID: {task_id}] {lora_type} task using specified steps: {ui_defaults['num_inference_steps']}")
+ elif "num_inference_steps" in task_params_dict:
+ ui_defaults["num_inference_steps"] = task_params_dict["num_inference_steps"]
+ print(f"[Task ID: {task_id}] {lora_type} task using specified num_inference_steps: {ui_defaults['num_inference_steps']}")
+ else:
+ ui_defaults["num_inference_steps"] = default_steps
+ print(f"[Task ID: {task_id}] {lora_type} task defaulting to steps: {ui_defaults['num_inference_steps']}")
+
+ # Set guidance and flow shift
+ ui_defaults["guidance_scale"] = guidance_scale
+ ui_defaults["flow_shift"] = flow_shift
+
+ # Set tea cache if specified
+ if tea_cache_setting is not None:
+ ui_defaults["tea_cache_setting"] = tea_cache_setting
+
+ # Handle LoRA activation
+ current_activated = _normalize_activated_loras_list(ui_defaults.get("activated_loras", []))
+
+ if lora_basename not in current_activated:
+ current_activated.append(lora_basename)
+ ui_defaults["activated_loras"] = current_activated
+
+ # Handle multipliers - simple approach for build_task_state
+ current_multipliers_str = ui_defaults.get("loras_multipliers", "")
+ multipliers_list = [m.strip() for m in current_multipliers_str.split(" ") if m.strip()] if current_multipliers_str else []
+ while len(multipliers_list) < len(current_activated):
+ multipliers_list.insert(0, "1.0")
+ ui_defaults["loras_multipliers"] = " ".join(multipliers_list)
+
+# --- SM_RESTRUCTURE: Function moved from headless.py ---
+def build_task_state(wgp_mod, model_filename, task_params_dict, all_loras_for_model, image_download_dir: Path | str | None = None, apply_reward_lora: bool = False):
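+    """Builds the wgp state dict and per-model ui_defaults for one task, merging task parameters,
+    start/end/reference images, and CausVid/LightI2X/Reward/additional LoRA settings."""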
+ state = {
+ "model_filename": model_filename,
+ "validate_success": 1,
+ "advanced": True,
+ "gen": {"queue": [], "file_list": [], "file_settings_list": [], "prompt_no": 1, "prompts_max": 1},
+ "loras": all_loras_for_model,
+ }
+ model_type_key = wgp_mod.get_model_type(model_filename)
+ ui_defaults = wgp_mod.get_default_settings(model_filename).copy()
+
+ # Override with task_params from JSON, but preserve some crucial ones if CausVid is used
+ causvid_active = task_params_dict.get("use_causvid_lora", False)
+ lighti2x_active = task_params_dict.get("use_lighti2x_lora", False)
+
+ for key, value in task_params_dict.items():
+ if key not in ["output_sub_dir", "model", "task_id", "use_causvid_lora", "use_lighti2x_lora"]:
+ if (causvid_active or lighti2x_active) and key in ["steps", "guidance_scale", "flow_shift", "activated_loras", "loras_multipliers"]:
+ continue # These will be set by causvid/lighti2x logic if flag is true
+ ui_defaults[key] = value
+
+ ui_defaults["prompt"] = task_params_dict.get("prompt", "Default prompt")
+ ui_defaults["resolution"] = task_params_dict.get("resolution", "832x480")
+ # Allow task to specify frames/video_length, steps, guidance_scale, flow_shift unless overridden by CausVid
+ if not (causvid_active or lighti2x_active):
+ ui_defaults["video_length"] = task_params_dict.get("frames", task_params_dict.get("video_length", 81))
+ ui_defaults["num_inference_steps"] = task_params_dict.get("steps", task_params_dict.get("num_inference_steps", 30))
+ ui_defaults["guidance_scale"] = task_params_dict.get("guidance_scale", ui_defaults.get("guidance_scale", 5.0))
+ ui_defaults["flow_shift"] = task_params_dict.get("flow_shift", ui_defaults.get("flow_shift", 3.0))
+ else: # CausVid or LightI2X specific defaults if not touched by their logic yet
+ ui_defaults["video_length"] = task_params_dict.get("frames", task_params_dict.get("video_length", 81))
+ # steps, guidance_scale, flow_shift will be set below by specialised logic
+
+ ui_defaults["seed"] = task_params_dict.get("seed", -1)
+ ui_defaults["lset_name"] = ""
+
+ current_task_id_for_log = task_params_dict.get('task_id', 'build_task_state_unknown')
+
+ if task_params_dict.get("image_start_paths"):
+ loaded = load_pil_images(
+ task_params_dict["image_start_paths"],
+ wgp_mod.convert_image,
+ image_download_dir,
+ current_task_id_for_log,
+ dprint
+ )
+ if loaded: ui_defaults["image_start"] = loaded
+
+ if task_params_dict.get("image_end_paths"):
+ loaded = load_pil_images(
+ task_params_dict["image_end_paths"],
+ wgp_mod.convert_image,
+ image_download_dir,
+ current_task_id_for_log,
+ dprint
+ )
+ if loaded: ui_defaults["image_end"] = loaded
+
+ if task_params_dict.get("image_refs_paths"):
+ loaded = load_pil_images(
+ task_params_dict["image_refs_paths"],
+ wgp_mod.convert_image,
+ image_download_dir,
+ current_task_id_for_log,
+ dprint
+ )
+ if loaded: ui_defaults["image_refs"] = loaded
+
+ for key in ["video_source_path", "video_guide_path", "video_mask_path", "audio_guide_path"]:
+ if task_params_dict.get(key):
+ ui_defaults[key.replace("_path","")] = task_params_dict[key]
+
+ if task_params_dict.get("prompt_enhancer_mode"):
+ ui_defaults["prompt_enhancer"] = task_params_dict["prompt_enhancer_mode"]
+ wgp_mod.server_config["enhancer_enabled"] = 1
+ elif "prompt_enhancer" not in task_params_dict:
+ ui_defaults["prompt_enhancer"] = ""
+ wgp_mod.server_config["enhancer_enabled"] = 0
+
+ # --- Custom LoRA Handling (e.g., from lora_name) ---
+ custom_lora_name_stem = task_params_dict.get("lora_name")
+ task_id_for_dprint = task_params_dict.get('task_id', 'N/A') # For logging
+
+ if custom_lora_name_stem:
+ custom_lora_filename = f"{custom_lora_name_stem}.safetensors"
+ dprint(f"[Task ID: {task_id_for_dprint}] Custom LoRA specified via lora_name: {custom_lora_filename}")
+
+ # Ensure activated_loras is a list
+ activated_loras_val = ui_defaults.get("activated_loras", [])
+ if isinstance(activated_loras_val, str):
+ # Handles comma-separated string from task_params or previous logic
+ current_activated_list = [str(item).strip() for item in activated_loras_val.split(',') if item.strip()]
+ elif isinstance(activated_loras_val, list):
+ current_activated_list = list(activated_loras_val) # Ensure it's a mutable copy
+ else:
+ dprint(f"[Task ID: {task_id_for_dprint}] Unexpected type for activated_loras: {type(activated_loras_val)}. Initializing as empty list.")
+ current_activated_list = []
+
+ if custom_lora_filename not in current_activated_list:
+ current_activated_list.append(custom_lora_filename)
+ dprint(f"[Task ID: {task_id_for_dprint}] Added '{custom_lora_filename}' to activated_loras list: {current_activated_list}")
+
+ # Handle multipliers: Add a default "1.0" if a LoRA was added and multipliers are potentially mismatched
+ loras_multipliers_str = ui_defaults.get("loras_multipliers", "")
+ if isinstance(loras_multipliers_str, (list, tuple)):
+ loras_multipliers_list = [str(m).strip() for m in loras_multipliers_str if str(m).strip()] # Convert all to string and clean
+ elif isinstance(loras_multipliers_str, str):
+ loras_multipliers_list = [m.strip() for m in loras_multipliers_str.split(" ") if m.strip()] # Space-separated string
+ else:
+ dprint(f"[Task ID: {task_id_for_dprint}] Unexpected type for loras_multipliers: {type(loras_multipliers_str)}. Initializing as empty list.")
+ loras_multipliers_list = []
+
+ # If number of multipliers is less than activated LoRAs, pad with "1.0"
+ while len(loras_multipliers_list) < len(current_activated_list):
+ loras_multipliers_list.append("1.0")
+ dprint(f"[Task ID: {task_id_for_dprint}] Padded loras_multipliers with '1.0'. Now: {loras_multipliers_list}")
+
+ ui_defaults["loras_multipliers"] = " ".join(loras_multipliers_list)
+ else:
+ dprint(f"[Task ID: {task_id_for_dprint}] Custom LoRA '{custom_lora_filename}' already in activated_loras list.")
+
+ ui_defaults["activated_loras"] = current_activated_list # Update ui_defaults
+ # --- End Custom LoRA Handling ---
+
+ # Apply special LoRA settings using shared helper
+ if causvid_active:
+ _apply_special_lora_settings(
+ current_task_id_for_log, "CausVid",
+ "Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+ default_steps=9, guidance_scale=1.0, flow_shift=1.0,
+ ui_defaults=ui_defaults, task_params_dict=task_params_dict
+ )
+
+ if lighti2x_active:
+ _apply_special_lora_settings(
+ current_task_id_for_log, "LightI2X",
+ "wan_lcm_r16_fp32_comfy.safetensors",
+ default_steps=4, guidance_scale=1.0, flow_shift=5.0,
+ ui_defaults=ui_defaults, task_params_dict=task_params_dict,
+ tea_cache_setting=0.0
+ )
+ # Additional LightI2X-specific settings
+ ui_defaults["sample_solver"] = "unipc"
+ ui_defaults["denoise_strength"] = 1.0
+
+ if apply_reward_lora:
+ print(f"[Task ID: {task_params_dict.get('task_id')}] Applying Reward LoRA settings.")
+
+ reward_lora = {"filename": "Wan2.1-Fun-14B-InP-MPS_reward_lora_wgp.safetensors", "strength": "0.5"}
+
+ # Get current activated LoRAs
+        current_activated = _normalize_activated_loras_list(ui_defaults.get("activated_loras", []))
+
+ # Get current multipliers
+ current_multipliers_str = ui_defaults.get("loras_multipliers", "")
+ if isinstance(current_multipliers_str, (list, tuple)):
+ current_multipliers_list = [str(m).strip() for m in current_multipliers_str if str(m).strip()]
+ elif isinstance(current_multipliers_str, str):
+ current_multipliers_list = [m.strip() for m in current_multipliers_str.split(" ") if m.strip()]
+ else:
+ current_multipliers_list = []
+
+ # Pad multipliers to match activated LoRAs before creating map
+ while len(current_multipliers_list) < len(current_activated):
+ current_multipliers_list.append("1.0")
+
+ # Create a dictionary to map lora to multiplier for easy update (preserves order in Python 3.7+)
+ lora_mult_map = dict(zip(current_activated, current_multipliers_list))
+
+ # Add/update reward lora
+ lora_mult_map[reward_lora['filename']] = reward_lora['strength']
+
+ ui_defaults["activated_loras"] = list(lora_mult_map.keys())
+ ui_defaults["loras_multipliers"] = " ".join(list(lora_mult_map.values()))
+ dprint(f"Reward LoRA applied. Activated: {ui_defaults['activated_loras']}, Multipliers: {ui_defaults['loras_multipliers']}")
+
+ # Apply additional LoRAs that may have been passed via task params (e.g. from travel orchestrator)
+ processed_additional_loras = task_params_dict.get("processed_additional_loras", {})
+ if processed_additional_loras:
+ dprint(f"[Task ID: {task_id_for_dprint}] Applying processed additional LoRAs: {processed_additional_loras}")
+
+ # Get current activated LoRAs and multipliers again, as they may have been modified by other logic.
+ current_activated = ui_defaults.get("activated_loras", [])
+ if not isinstance(current_activated, list):
+ try:
+ current_activated = [str(item).strip() for item in str(current_activated).split(',') if item.strip()]
+ except Exception:
+ current_activated = []
+
+ current_multipliers_str = ui_defaults.get("loras_multipliers", "")
+ if isinstance(current_multipliers_str, (list, tuple)):
+ current_multipliers_list = [str(m).strip() for m in current_multipliers_str if str(m).strip()]
+ elif isinstance(current_multipliers_str, str):
+ current_multipliers_list = [m.strip() for m in current_multipliers_str.split(" ") if m.strip()]
+ else:
+ current_multipliers_list = []
+
+ # Pad multipliers to match activated LoRAs before creating map
+ while len(current_multipliers_list) < len(current_activated):
+ current_multipliers_list.append("1.0")
+
+ lora_mult_map = dict(zip(current_activated, current_multipliers_list))
+
+ # Add/update additional loras - this will overwrite strength if lora was already present
+ lora_mult_map.update(processed_additional_loras)
+
+ ui_defaults["activated_loras"] = list(lora_mult_map.keys())
+ ui_defaults["loras_multipliers"] = " ".join(list(lora_mult_map.values()))
+ dprint(f"Additional LoRAs applied. Final Activated: {ui_defaults['activated_loras']}, Final Multipliers: {ui_defaults['loras_multipliers']}")
+
+ state[model_type_key] = ui_defaults
+ return state, ui_defaults
+
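+# Illustrative sketch of the LoRA/multiplier merge performed above (values are hypothetical,
+# not part of the original module). Multipliers are padded with "1.0" to match the number of
+# activated LoRAs, then zipped into an order-preserving dict so later entries can overwrite
+# strengths (e.g. the Reward LoRA):
+#
+#     activated = ["styleA.safetensors", "detail.safetensors"]
+#     multipliers = ["0.8"]
+#     multipliers += ["1.0"] * (len(activated) - len(multipliers))   # pad to match
+#     lora_map = dict(zip(activated, multipliers))
+#     lora_map["Wan2.1-Fun-14B-InP-MPS_reward_lora_wgp.safetensors"] = "0.5"
+#     # activated_loras   -> list(lora_map.keys())
+#     # loras_multipliers -> " ".join(lora_map.values())   e.g. "0.8 1.0 0.5"
+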
+def prepare_output_path(
+ task_id: str,
+ filename: str,
+ main_output_dir_base: Path,
+ *,
+ dprint=lambda *_: None,
+ custom_output_dir: str | Path | None = None
+) -> tuple[Path, str]:
+ """
+ Prepares the output path for a task's artifact.
+
+ If `custom_output_dir` is provided, it's used as the base. Otherwise,
+ the output is placed in a directory named after the task_id inside
+ `main_output_dir_base`.
+ """
+ # Import DB configuration lazily to avoid circular dependencies.
+ try:
+ from source import db_operations as db_ops # type: ignore
+ except Exception: # pragma: no cover
+ db_ops = None
+
+ # Decide base directory for the file
+ if custom_output_dir:
+ output_dir_for_task = Path(custom_output_dir)
+ dprint(f"Task {task_id}: Using custom output directory: {output_dir_for_task}")
+ else:
+ if db_ops and db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH:
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ output_dir_for_task = sqlite_db_parent / "public" / "files"
+ dprint(f"Task {task_id}: Using SQLite public files directory: {output_dir_for_task}")
+ else:
+ # Flatten: put all artefacts directly in the main output dir, no per-task folder
+ output_dir_for_task = main_output_dir_base
+
+ # To avoid name collisions we prefix the filename with the task_id
+ if not filename.startswith(task_id):
+ filename = f"{task_id}_{filename}"
+
+ dprint(
+ f"Task {task_id}: Using flattened output directory: {output_dir_for_task} (filename '{filename}')"
+ )
+
+ output_dir_for_task.mkdir(parents=True, exist_ok=True)
+
+ final_save_path = output_dir_for_task / filename
+
+ # Build DB path string
+ try:
+ if db_ops and db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH:
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ db_output_location = str(final_save_path.relative_to(sqlite_db_parent / "public"))
+ else:
+ db_output_location = str(final_save_path.relative_to(Path.cwd()))
+ except ValueError:
+ db_output_location = str(final_save_path.resolve())
+
+ dprint(f"Task {task_id}: final_save_path='{final_save_path}', db_output_location='{db_output_location}'")
+
+ return final_save_path, db_output_location
+
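+# Hypothetical usage sketch for prepare_output_path (task id and paths are examples only):
+#
+#     save_path, db_loc = prepare_output_path(
+#         "task_abc123", "segment_00.mp4", Path("./outputs"), dprint=print
+#     )
+#     # When the worker is not using the SQLite public/files layout:
+#     #   save_path -> ./outputs/task_abc123_segment_00.mp4
+#     #   db_loc    -> path relative to the CWD, or the absolute path if not beneath it
+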
+def prepare_output_path_with_upload(
+ task_id: str,
+ filename: str,
+ main_output_dir_base: Path,
+ *,
+ dprint=lambda *_: None,
+ custom_output_dir: str | Path | None = None
+) -> tuple[Path, str]:
+ """
+ Prepares the output path for a task's artifact and handles Supabase upload if configured.
+
+ Returns:
+ tuple[Path, str]: (local_file_path, db_output_location)
+ - local_file_path: Where to save the file locally (for generation)
+ - db_output_location: What to store in the database (local path or Supabase URL)
+ """
+ # Import DB configuration lazily to avoid circular dependencies.
+ try:
+ from source import db_operations as db_ops # type: ignore
+ except Exception: # pragma: no cover
+ db_ops = None
+
+ # First, get the local path where we'll save the file
+ local_save_path, initial_db_location = prepare_output_path(
+ task_id, filename, main_output_dir_base,
+ dprint=dprint, custom_output_dir=custom_output_dir
+ )
+
+ # Return the local path for now - we'll handle Supabase upload after file is created
+ return local_save_path, initial_db_location
+
+def upload_and_get_final_output_location(
+ local_file_path: Path,
+ supabase_object_name: str, # This parameter is now unused but kept for compatibility
+ initial_db_location: str,
+ *,
+ dprint=lambda *_: None
+) -> str:
+ """
+ Returns the local file path. Upload is now handled by the edge function.
+
+ Args:
+ local_file_path: Path to the local file
+ supabase_object_name: Unused (kept for compatibility)
+ initial_db_location: The initial DB location (local path)
+ dprint: Debug print function
+
+ Returns:
+ str: Local file path (upload now handled by edge function)
+ """
+ # Edge function will handle the upload, so we just return the local path
+ dprint(f"File ready for edge function upload: {local_file_path}")
+ return str(local_file_path.resolve())
+
+def create_mask_video_from_inactive_indices(
+ total_frames: int,
+ resolution_wh: tuple[int, int],
+ inactive_frame_indices: set[int] | list[int],
+ output_path: Path | str,
+ fps: int = 16,
+ task_id_for_logging: str = "unknown",
+ *, dprint = print
+) -> Path | None:
+ """
+ Create a mask video where:
+ - Black frames (0) = inactive/keep original - don't edit these frames
+ - White frames (255) = active/generate - model should generate these frames
+
+ Args:
+ total_frames: Total number of frames in the video
+ resolution_wh: (width, height) tuple for video resolution
+ inactive_frame_indices: Set or list of frame indices that should be black (inactive)
+ output_path: Where to save the mask video
+ fps: Frames per second for the output video
+ task_id_for_logging: Task ID for debug logging
+ dprint: Print function for logging
+
+ Returns:
+ Path to created mask video, or None if creation failed
+ """
+ try:
+ if total_frames <= 0:
+ dprint(f"[WARNING] Task {task_id_for_logging}: Cannot create mask video with {total_frames} frames")
+ return None
+
+ h, w = resolution_wh[1], resolution_wh[0] # height, width
+ inactive_set = set(inactive_frame_indices) if not isinstance(inactive_frame_indices, set) else inactive_frame_indices
+
+ dprint(f"Task {task_id_for_logging}: Creating mask video - total_frames={total_frames}, "
+ f"inactive_indices={sorted(list(inactive_set))[:10]}{'...' if len(inactive_set) > 10 else ''}")
+
+ # Create mask frames: 0 (black) for inactive, 255 (white) for active
+ mask_frames_buf: list[np.ndarray] = [
+ np.full((h, w, 3), 0 if idx in inactive_set else 255, dtype=np.uint8)
+ for idx in range(total_frames)
+ ]
+
+ created_mask_video = create_video_from_frames_list(
+ mask_frames_buf,
+ Path(output_path),
+ fps,
+ resolution_wh
+ )
+
+ if created_mask_video and created_mask_video.exists():
+ dprint(f"Task {task_id_for_logging}: Mask video created successfully at {created_mask_video}")
+ return created_mask_video
+ else:
+ dprint(f"[WARNING] Task {task_id_for_logging}: Failed to create mask video at {output_path}")
+ return None
+
+ except Exception as e:
+ dprint(f"[ERROR] Task {task_id_for_logging}: Mask video creation failed: {e}")
+ return None
+
+def create_simple_first_frame_mask_video(
+ total_frames: int,
+ resolution_wh: tuple[int, int],
+ output_path: Path | str,
+ fps: int = 16,
+ task_id_for_logging: str = "unknown",
+ *, dprint = print
+) -> Path | None:
+ """
+ Convenience function to create a mask video where only the first frame is inactive (black).
+ This is useful for workflows like different_perspective where you want to keep the first frame unchanged
+ and generate the rest.
+
+ Returns:
+ Path to created mask video, or None if creation failed
+ """
+ return create_mask_video_from_inactive_indices(
+ total_frames=total_frames,
+ resolution_wh=resolution_wh,
+ inactive_frame_indices={0}, # Only first frame is inactive
+ output_path=output_path,
+ fps=fps,
+ task_id_for_logging=task_id_for_logging,
+ dprint=dprint
+ )
+
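+# Hypothetical usage sketch: keep frame 0 unchanged and let the model generate the rest
+# (frame count, resolution and paths are examples only):
+#
+#     mask_path = create_simple_first_frame_mask_video(
+#         total_frames=73,
+#         resolution_wh=(832, 480),
+#         output_path=Path("./outputs/first_frame_mask.mp4"),
+#         fps=16,
+#         task_id_for_logging="task_abc123",
+#     )
+#     # Frame 0 is black (kept as-is); frames 1..72 are white (to be generated).
+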
+def wait_for_file_stable(path: Path | str, checks: int = 3, interval: float = 1.0, *, dprint=print) -> bool:
+ """Return True when the file size stays constant for a few consecutive checks.
+ Useful to make sure long-running encoders have finished writing before we
+ copy/move the file.
+ """
+ p = Path(path)
+ if not p.exists():
+ return False
+ last_size = p.stat().st_size
+ stable_count = 0
+ for _ in range(checks):
+ time.sleep(interval)
+ new_size = p.stat().st_size
+ if new_size == last_size and new_size > 0:
+ stable_count += 1
+ if stable_count >= checks - 1:
+ return True
+ else:
+ stable_count = 0
+ last_size = new_size
+ return False
+
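+# Hypothetical usage sketch for wait_for_file_stable (path is an example only):
+#
+#     if wait_for_file_stable(Path("./outputs/task_abc123_segment_00.mp4"), checks=3, interval=1.0):
+#         ...  # encoder has finished writing; safe to copy/upload the file
+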
+def report_orchestrator_failure(task_params_dict: dict, error_msg: str, dprint: callable = print) -> None:
+ """Update the parent orchestrator task to FAILED when a sub-task encounters a fatal error.
+
+ Args:
+ task_params_dict: The params payload of the *current* sub-task.
+ It is expected to contain a reference to the orchestrator via one
+ of the standard keys (e.g. ``orchestrator_task_id_ref``).
+ error_msg: Human-readable message describing the failure.
+ dprint: Debug print helper (typically passed from the caller).
+ """
+ # Defer import to avoid potential circular dependencies at module import time
+ try:
+ from source import db_operations as db_ops # type: ignore
+ except Exception as e: # pragma: no cover
+ dprint(f"[report_orchestrator_failure] Could not import db_operations: {e}")
+ return
+
+ orchestrator_id = None
+ # Common payload keys that may reference the orchestrator task
+ for key in (
+ "orchestrator_task_id_ref",
+ "orchestrator_task_id",
+ "parent_orchestrator_task_id",
+ "orchestrator_id",
+ ):
+ orchestrator_id = task_params_dict.get(key)
+ if orchestrator_id:
+ break
+
+ if not orchestrator_id:
+ dprint(
+ f"[report_orchestrator_failure] No orchestrator reference found in payload. Message: {error_msg}"
+ )
+ return
+
+ # Truncate very long messages to avoid DB column overflow
+ truncated_msg = error_msg[:500]
+
+ try:
+ db_ops.update_task_status(
+ orchestrator_id,
+ db_ops.STATUS_FAILED,
+ truncated_msg,
+ )
+ dprint(
+ f"[report_orchestrator_failure] Marked orchestrator task {orchestrator_id} as FAILED with message: {truncated_msg}"
+ )
+ except Exception as e_update: # pragma: no cover
+ dprint(
+ f"[report_orchestrator_failure] Failed to update orchestrator status for {orchestrator_id}: {e_update}"
+ )
+
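+# Hypothetical usage sketch: a sub-task hitting a fatal error and surfacing it on its parent
+# orchestrator (keys and values are examples only):
+#
+#     report_orchestrator_failure(
+#         {"orchestrator_task_id_ref": "orch_123", "segment_index": 2},
+#         "Segment 2 failed: CUDA out of memory",
+#         dprint=print,
+#     )
+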
+def validate_lora_file(file_path: Path, filename: str) -> tuple[bool, str]:
+ """
+ Validates a LoRA file for size and format integrity.
+
+ Returns:
+ (is_valid, error_message)
+ """
+ if not file_path.exists():
+ return False, f"File does not exist: {file_path}"
+
+ file_size = file_path.stat().st_size
+
+ # Known LoRA size ranges (in bytes)
+ # These are based on common LoRA architectures and rank sizes
+ LORA_SIZE_RANGES = {
+ # Very small LoRAs (rank 4-8)
+ 'tiny': (1_000_000, 50_000_000), # 1MB - 50MB
+ # Standard LoRAs (rank 16-32)
+ 'standard': (50_000_000, 500_000_000), # 50MB - 500MB
+ # Large LoRAs (rank 64+) or full model fine-tunes
+ 'large': (500_000_000, 5_000_000_000), # 500MB - 5GB
+ # Extremely large (full model weights)
+ 'xlarge': (5_000_000_000, 50_000_000_000) # 5GB - 50GB
+ }
+
+ # Check if file size is within any reasonable range
+ in_valid_range = any(
+ min_size <= file_size <= max_size
+ for min_size, max_size in LORA_SIZE_RANGES.values()
+ )
+
+ if not in_valid_range:
+ if file_size < 1_000_000: # Less than 1MB
+ return False, f"File too small ({file_size:,} bytes) - likely corrupted or incomplete download"
+ elif file_size > 50_000_000_000: # More than 50GB
+ return False, f"File too large ({file_size:,} bytes) - likely not a LoRA file"
+
+ # For safetensors files, try to open and inspect
+ if filename.endswith('.safetensors'):
+ try:
+ import safetensors.torch as st
+ with st.safe_open(file_path, framework="pt") as f:
+ # Get metadata to verify it's actually a LoRA
+ metadata = f.metadata()
+ keys = list(f.keys())
+
+ # LoRAs typically have keys like "lora_down.weight", "lora_up.weight", etc.
+ lora_indicators = ['lora_down', 'lora_up', 'lora.down', 'lora.up', 'lora_A', 'lora_B']
+ has_lora_keys = any(indicator in key for key in keys for indicator in lora_indicators)
+
+ if not has_lora_keys and len(keys) > 100:
+ # Might be a full model checkpoint rather than a LoRA
+ print(f"[WARNING] {filename} appears to be a full model checkpoint ({len(keys)} tensors) rather than a LoRA")
+ elif not has_lora_keys:
+ print(f"[WARNING] {filename} doesn't appear to contain LoRA weights (no lora_* keys found)")
+
+ # Check for reasonable number of parameters
+ if len(keys) == 0:
+ return False, "Safetensors file contains no tensors"
+ elif len(keys) > 10000:
+ print(f"[WARNING] {filename} contains many tensors ({len(keys)}) - might be a full model")
+
+ except ImportError:
+ print(f"[WARNING] safetensors not available for detailed validation of {filename}")
+ except Exception as e:
+ return False, f"Safetensors file appears corrupted: {e}"
+
+ # Additional checks for common corruption patterns
+ if file_size == 0:
+ return False, "File is empty"
+
+ # For binary files, check they don't start with common error HTML patterns
+ try:
+ with open(file_path, 'rb') as f:
+ first_bytes = f.read(1024)
+ if first_bytes.startswith(b'<!DOCTYPE') or first_bytes.startswith(b'<html'):
+ return False, "File begins with HTML content - likely an error page saved by a failed download"
+ except Exception as e:
+ return False, f"Could not read file header for validation: {e}"
+
+ return True, f"Valid LoRA file ({file_size:,} bytes)"
+
+def check_lora_files_integrity(lora_dir: Path, fix_issues: bool = False) -> dict:
+ """
+ Checks all LoRA files in a directory for integrity issues.
+
+ Args:
+ lora_dir: Directory containing LoRA files
+ fix_issues: If True, removes corrupted files
+
+ Returns:
+ Dictionary with validation results
+ """
+ lora_dir = Path(lora_dir)
+ if not lora_dir.exists():
+ return {"error": f"Directory does not exist: {lora_dir}"}
+
+ results = {
+ "total_files": 0,
+ "valid_files": 0,
+ "invalid_files": 0,
+ "issues": [],
+ "summary": []
+ }
+
+ # Look for LoRA-like files
+ lora_extensions = ['.safetensors', '.bin', '.pt', '.pth']
+ lora_files = []
+
+ for ext in lora_extensions:
+ lora_files.extend(lora_dir.glob(f"*{ext}"))
+ lora_files.extend(lora_dir.glob(f"**/*{ext}")) # Include subdirectories
+
+ # Filter to likely LoRA files
+ lora_files = [f for f in lora_files if 'lora' in f.name.lower() or f.suffix == '.safetensors']
+
+ results["total_files"] = len(lora_files)
+
+ for lora_file in lora_files:
+ is_valid, validation_msg = validate_lora_file(lora_file, lora_file.name)
+
+ if is_valid:
+ results["valid_files"] += 1
+ results["summary"].append(f"✓ {lora_file.name}: {validation_msg}")
+ else:
+ results["invalid_files"] += 1
+ issue_msg = f"✗ {lora_file.name}: {validation_msg}"
+ results["issues"].append(issue_msg)
+ results["summary"].append(issue_msg)
+
+ if fix_issues:
+ try:
+ lora_file.unlink()
+ results["summary"].append(f" → Removed corrupted file: {lora_file.name}")
+ except Exception as e:
+ results["summary"].append(f" → Failed to remove {lora_file.name}: {e}")
+
+ return results
+
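+# Hypothetical usage sketch for the LoRA validation helpers (directory name is an example only):
+#
+#     report = check_lora_files_integrity(Path("loras_i2v"), fix_issues=False)
+#     print(f"{report['valid_files']}/{report['total_files']} LoRA files passed validation")
+#     for issue in report["issues"]:
+#         print(issue)
+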
+# --------------------------------------------------------------------------------------------------
diff --git a/source/db_operations.py b/source/db_operations.py
new file mode 100644
index 000000000..59b2aeb69
--- /dev/null
+++ b/source/db_operations.py
@@ -0,0 +1,1199 @@
+# source/db_operations.py
+import os
+import sys
+import json
+import time
+import traceback
+import datetime
+import sqlite3
+import urllib.parse
+import threading
+import httpx # For calling Supabase Edge Function
+from pathlib import Path
+import base64 # Added for JWT decoding
+
+try:
+ from supabase import create_client, Client as SupabaseClient
+except ImportError:
+ SupabaseClient = None
+
+# -----------------------------------------------------------------------------
+# Global DB Configuration (will be set by headless.py)
+# -----------------------------------------------------------------------------
+DB_TYPE = "sqlite"
+PG_TABLE_NAME = "tasks"
+SQLITE_DB_PATH = "tasks.db"
+SUPABASE_URL = None
+SUPABASE_SERVICE_KEY = None
+SUPABASE_VIDEO_BUCKET = "image_uploads"
+SUPABASE_CLIENT: SupabaseClient | None = None
+SUPABASE_EDGE_COMPLETE_TASK_URL: str | None = None # Optional override for edge function
+SUPABASE_ACCESS_TOKEN: str | None = None # Will be set by headless.py
+SUPABASE_EDGE_CREATE_TASK_URL: str | None = None # Will be set by headless.py
+SUPABASE_EDGE_CLAIM_TASK_URL: str | None = None # Will be set by headless.py
+
+sqlite_lock = threading.Lock()
+
+SQLITE_MAX_RETRIES = 5
+SQLITE_RETRY_DELAY = 0.5 # seconds
+
+# -----------------------------------------------------------------------------
+# Status Constants
+# -----------------------------------------------------------------------------
+STATUS_QUEUED = "Queued"
+STATUS_IN_PROGRESS = "In Progress"
+STATUS_COMPLETE = "Complete"
+STATUS_FAILED = "Failed"
+
+# -----------------------------------------------------------------------------
+# Debug / Verbose Logging Helpers
+# -----------------------------------------------------------------------------
+debug_mode = False
+
+def dprint(msg: str):
+ """Print a debug message if debug_mode is enabled."""
+ if debug_mode:
+ print(f"[DEBUG {datetime.datetime.now().isoformat()}] {msg}")
+
+# -----------------------------------------------------------------------------
+# Internal Helpers
+# -----------------------------------------------------------------------------
+
+def execute_sqlite_with_retry(db_path_str: str, operation_func, *args, **kwargs):
+ """Execute SQLite operations with retry logic for handling locks and I/O errors"""
+ for attempt in range(SQLITE_MAX_RETRIES):
+ try:
+ with sqlite_lock: # Ensure only one thread accesses SQLite at a time
+ conn = sqlite3.connect(db_path_str, timeout=30.0) # 30 second timeout
+ conn.execute("PRAGMA journal_mode=WAL") # Enable WAL mode for better concurrency
+ conn.execute("PRAGMA synchronous=NORMAL") # Balance between safety and performance
+ conn.execute("PRAGMA busy_timeout=30000") # 30 second busy timeout
+ try:
+ result = operation_func(conn, *args, **kwargs)
+ conn.commit()
+ return result
+ finally:
+ conn.close()
+ except (sqlite3.OperationalError, sqlite3.DatabaseError) as e:
+ error_msg = str(e).lower()
+ if "database is locked" in error_msg or "disk i/o error" in error_msg or "database disk image is malformed" in error_msg:
+ if attempt < SQLITE_MAX_RETRIES - 1:
+ wait_time = SQLITE_RETRY_DELAY * (2 ** attempt) # Exponential backoff
+ print(f"SQLite error on attempt {attempt + 1}: {e}. Retrying in {wait_time:.1f}s...")
+ time.sleep(wait_time)
+ continue
+ else:
+ print(f"SQLite error after {SQLITE_MAX_RETRIES} attempts: {e}")
+ raise
+ else:
+ # For other SQLite errors, don't retry
+ raise
+ except Exception as e:
+ # For non-SQLite errors, don't retry
+ raise
+
+ raise sqlite3.OperationalError(f"Failed to execute SQLite operation after {SQLITE_MAX_RETRIES} attempts")
+
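+# Hypothetical usage sketch: all SQLite access funnels through execute_sqlite_with_retry with a
+# callback that receives the open connection (query below is an example only):
+#
+#     def _count_queued(conn):
+#         cur = conn.cursor()
+#         cur.execute("SELECT COUNT(*) FROM tasks WHERE status = ?", ("Queued",))
+#         return cur.fetchone()[0]
+#
+#     queued = execute_sqlite_with_retry("tasks.db", _count_queued)
+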
+# -----------------------------------------------------------------------------
+# Internal Helpers for Supabase
+# -----------------------------------------------------------------------------
+
+def _get_user_id_from_jwt(jwt_str: str) -> str | None:
+ """Decodes a JWT and extracts the 'sub' (user ID) claim without validation."""
+ if not jwt_str:
+ return None
+ try:
+ # JWT is composed of header.payload.signature
+ _, payload_b64, _ = jwt_str.split('.')
+ # JWT payloads are base64url-encoded and may omit padding; restore padding, then decode.
+ payload_b64 += '=' * (-len(payload_b64) % 4)
+ payload_json = base64.urlsafe_b64decode(payload_b64).decode('utf-8')
+ payload = json.loads(payload_json)
+ user_id = payload.get('sub')
+ dprint(f"JWT Decode: Extracted user ID (sub): {user_id}")
+ return user_id
+ except Exception as e:
+ dprint(f"[ERROR] Could not decode JWT to get user ID: {e}")
+ return None
+
+def _is_jwt_token(token_str: str) -> bool:
+ """
+ Checks if a token string looks like a JWT (has 3 parts separated by dots).
+ """
+ if not token_str:
+ return False
+ parts = token_str.split('.')
+ return len(parts) == 3
+
+def _mark_task_failed_via_edge_function(task_id_str: str, error_message: str):
+ """Mark a task as failed using the update-task-status Edge Function"""
+ try:
+ edge_url = (
+ os.getenv("SUPABASE_EDGE_UPDATE_TASK_URL")
+ or (f"{SUPABASE_URL.rstrip('/')}/functions/v1/update-task-status" if SUPABASE_URL else None)
+ )
+
+ if not edge_url:
+ print(f"[ERROR] No update-task-status edge function URL available for marking task {task_id_str} as failed")
+ return
+
+ headers = {"Content-Type": "application/json"}
+ if SUPABASE_ACCESS_TOKEN:
+ headers["Authorization"] = f"Bearer {SUPABASE_ACCESS_TOKEN}"
+
+ payload = {
+ "task_id": task_id_str,
+ "status": STATUS_FAILED,
+ "output_location": error_message
+ }
+
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=30)
+
+ if resp.status_code == 200:
+ dprint(f"[DEBUG] Successfully marked task {task_id_str} as Failed via Edge Function")
+ else:
+ print(f"[ERROR] Failed to mark task {task_id_str} as Failed: {resp.status_code} - {resp.text}")
+
+ except Exception as e:
+ print(f"[ERROR] Exception marking task {task_id_str} as Failed: {e}")
+
+# -----------------------------------------------------------------------------
+# Public Database Functions
+# -----------------------------------------------------------------------------
+
+def _migrate_sqlite_schema(db_path_str: str):
+ """Applies necessary schema migrations to an existing SQLite database."""
+ dprint(f"SQLite Migration: Checking schema for {db_path_str}...")
+ try:
+ def migration_operations(conn):
+ cursor = conn.cursor()
+
+ # --- Check if 'tasks' table exists before attempting migrations ---
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='tasks'")
+ if cursor.fetchone() is None:
+ dprint("SQLite Migration: 'tasks' table does not exist. Skipping schema migration steps. init_db will create it.")
+ return True # Indicate success as there's nothing to migrate on a non-existent table
+
+ # Check if task_type column exists
+ cursor.execute(f"PRAGMA table_info(tasks)")
+ columns = [row[1] for row in cursor.fetchall()]
+ task_type_column_exists = 'task_type' in columns
+
+ if not task_type_column_exists:
+ dprint("SQLite Migration: 'task_type' column not found. Adding it.")
+ cursor.execute("ALTER TABLE tasks ADD COLUMN task_type TEXT") # Add as nullable first
+ conn.commit() # Commit alter table before data migration
+ dprint("SQLite Migration: 'task_type' column added.")
+ else:
+ dprint("SQLite Migration: 'task_type' column already exists.")
+
+ # --- Add/Rename dependant_on column if not exists --- (Section 2.2)
+ dependant_on_column_exists = 'dependant_on' in columns
+ depends_on_column_exists_old_name = 'depends_on' in columns # Check for the previously incorrect name
+
+ if not dependant_on_column_exists:
+ if depends_on_column_exists_old_name:
+ dprint("SQLite Migration: Found old 'depends_on' column. Renaming to 'dependant_on'.")
+ try:
+ # Ensure no other column is already named 'dependant_on' before renaming
+ # This scenario is unlikely if migrations are run sequentially but good for robustness
+ if 'dependant_on' not in columns:
+ cursor.execute("ALTER TABLE tasks RENAME COLUMN depends_on TO dependant_on")
+ dprint("SQLite Migration: Renamed 'depends_on' to 'dependant_on'.")
+ dependant_on_column_exists = True # Mark as existing now
+ else:
+ dprint("SQLite Migration: 'dependant_on' column already exists. Skipping rename of 'depends_on'.")
+ except sqlite3.OperationalError as e_rename:
+ dprint(f"SQLite Migration: Could not rename 'depends_on' to 'dependant_on' (perhaps 'dependant_on' already exists or other issue): {e_rename}. Will attempt to ADD 'dependant_on' if it truly doesn't exist after this.")
+ # Re-check columns after attempted rename or if it failed
+ cursor.execute(f"PRAGMA table_info(tasks)")
+ rechecked_columns = [row[1] for row in cursor.fetchall()]
+ if 'dependant_on' not in rechecked_columns:
+ dprint("SQLite Migration: 'dependant_on' still not found after rename attempt, adding new column.")
+ cursor.execute("ALTER TABLE tasks ADD COLUMN dependant_on TEXT NULL")
+ dprint("SQLite Migration: 'dependant_on' column added.")
+ else:
+ dprint("SQLite Migration: 'dependant_on' column now exists (possibly due to a concurrent migration or complex rename scenario).")
+ dependant_on_column_exists = True
+ else:
+ dprint("SQLite Migration: 'dependant_on' column not found and no 'depends_on' to rename. Adding new 'dependant_on' column.")
+ cursor.execute("ALTER TABLE tasks ADD COLUMN dependant_on TEXT NULL")
+ dprint("SQLite Migration: 'dependant_on' column added.")
+ else:
+ dprint("SQLite Migration: 'dependant_on' column already exists.")
+ # --- End add/rename dependant_on column ---
+
+ # Ensure the index for dependant_on exists
+ cursor.execute("PRAGMA index_list(tasks)")
+ indexes = [row[1] for row in cursor.fetchall()]
+ if 'idx_dependant_on' not in indexes:
+ dprint("SQLite Migration: Creating 'idx_dependant_on' index.")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_dependant_on ON tasks(dependant_on)")
+ else:
+ dprint("SQLite Migration: 'idx_dependant_on' index already exists.")
+
+ # --- Add generation_started_at column if not exists ---
+ generation_started_at_column_exists = 'generation_started_at' in columns
+ if not generation_started_at_column_exists:
+ dprint("SQLite Migration: 'generation_started_at' column not found. Adding it.")
+ cursor.execute("ALTER TABLE tasks ADD COLUMN generation_started_at TEXT NULL")
+ dprint("SQLite Migration: 'generation_started_at' column added.")
+ else:
+ dprint("SQLite Migration: 'generation_started_at' column already exists.")
+
+ # --- Add generation_processed_at column if not exists ---
+ generation_processed_at_column_exists = 'generation_processed_at' in columns
+ if not generation_processed_at_column_exists:
+ dprint("SQLite Migration: 'generation_processed_at' column not found. Adding it.")
+ cursor.execute("ALTER TABLE tasks ADD COLUMN generation_processed_at TEXT NULL")
+ dprint("SQLite Migration: 'generation_processed_at' column added.")
+ else:
+ dprint("SQLite Migration: 'generation_processed_at' column already exists.")
+
+ # Populate task_type from params if it's NULL (for old rows or newly added column)
+ dprint("SQLite Migration: Attempting to populate NULL 'task_type' from 'params' JSON...")
+ cursor.execute("SELECT id, params FROM tasks WHERE task_type IS NULL")
+ rows_to_migrate = cursor.fetchall()
+
+ migrated_count = 0
+ for task_id, params_json_str in rows_to_migrate:
+ try:
+ params_dict = json.loads(params_json_str)
+ # Attempt to get task_type from common old locations within params
+ # The user might need to adjust these keys if their old storage was different
+ old_task_type = params_dict.get("task_type") # Most likely if it was in params
+
+ if old_task_type:
+ dprint(f"SQLite Migration: Found task_type '{old_task_type}' in params for task_id {task_id}. Updating row.")
+ cursor.execute("UPDATE tasks SET task_type = ? WHERE id = ?", (old_task_type, task_id))
+ migrated_count += 1
+ else:
+ # If task_type is not in params, it might be inferred from 'model' or other fields
+ # For instance, if 'model' field implied the task type for older tasks.
+ # This part is highly dependent on previous conventions.
+ # As a simple default, if not found, it will remain NULL unless a default is set.
+ # For 'travel_between_images' and 'different_perspective', these are typically set by steerable_motion.py
+ # and wouldn't exist as 'task_type' inside params for headless.py's default processing.
+ # Headless tasks like 'generate_openpose' *did* use task_type in params.
+ dprint(f"SQLite Migration: No 'task_type' key in params for task_id {task_id}. It will remain NULL or needs manual/specific migration logic if it was inferred differently.")
+ except json.JSONDecodeError:
+ dprint(f"SQLite Migration: Could not parse params JSON for task_id {task_id}. Skipping 'task_type' population for this row.")
+ except Exception as e_row:
+ dprint(f"SQLite Migration: Error processing row for task_id {task_id}: {e_row}")
+
+ if migrated_count > 0:
+ conn.commit()
+ dprint(f"SQLite Migration: Populated 'task_type' for {migrated_count} rows from params.")
+
+ # Default remaining NULL task_types for old standard tasks
+ # This ensures rows that didn't have an explicit 'task_type' in their params (e.g. old default WGP tasks)
+ # get a value, respecting the NOT NULL constraint if the table is new or fully validated.
+ default_task_type_for_old_rows = "standard_wgp_task"
+ cursor.execute(
+ f"UPDATE tasks SET task_type = ? WHERE task_type IS NULL",
+ (default_task_type_for_old_rows,)
+ )
+ updated_to_default_count = cursor.rowcount
+ if updated_to_default_count > 0:
+ conn.commit()
+ dprint(f"SQLite Migration: Updated {updated_to_default_count} older rows with NULL task_type to default '{default_task_type_for_old_rows}'.")
+
+ dprint("SQLite Migration: Schema check and population attempt complete.")
+ return True
+
+ execute_sqlite_with_retry(db_path_str, migration_operations)
+
+ except Exception as e:
+ print(f"[ERROR] SQLite Migration: Failed to migrate schema for {db_path_str}: {e}")
+ traceback.print_exc()
+ # Depending on severity, you might want to sys.exit(1)
+
+def _init_db_sqlite(db_path_str: str):
+ """Initialize the SQLite database with proper error handling"""
+ def _init_operation(conn):
+ cursor = conn.cursor()
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS tasks (
+ id TEXT PRIMARY KEY,
+ task_type TEXT NOT NULL,
+ params TEXT NOT NULL,
+ status TEXT NOT NULL DEFAULT 'Queued',
+ dependant_on TEXT NULL,
+ output_location TEXT NULL,
+ created_at TEXT NOT NULL,
+ updated_at TEXT NULL,
+ generation_started_at TEXT NULL,
+ generation_processed_at TEXT NULL,
+ project_id TEXT NOT NULL,
+ FOREIGN KEY (project_id) REFERENCES projects(id)
+ )
+ """)
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_status_created ON tasks(status, created_at)")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_dependant_on ON tasks(dependant_on)")
+ return True
+
+ try:
+ execute_sqlite_with_retry(db_path_str, _init_operation)
+ print(f"SQLite database initialized: {db_path_str}")
+ except Exception as e:
+ print(f"Failed to initialize SQLite database: {e}")
+ sys.exit(1)
+
+def get_oldest_queued_task_sqlite(db_path_str: str):
+ """Get the oldest queued task with proper error handling"""
+ def _get_operation(conn):
+ cursor = conn.cursor()
+ # Modified to select tasks with status 'Queued' only
+ # Also fetch project_id
+ sql_query = f"""
+ SELECT t.id, t.params, t.task_type, t.project_id
+ FROM tasks AS t
+ LEFT JOIN tasks AS d ON d.id = t.dependant_on
+ WHERE t.status = ?
+ AND (t.dependant_on IS NULL OR d.status = ?)
+ ORDER BY t.created_at ASC
+ LIMIT 1
+ """
+ query_params = (STATUS_QUEUED, STATUS_COMPLETE)
+
+ cursor.execute(sql_query, query_params)
+ task_row = cursor.fetchone()
+ if task_row:
+ task_id = task_row[0]
+ dprint(f"SQLite: Fetched raw task_row: {task_row}")
+
+ # Update status to IN_PROGRESS and set generation_started_at
+ current_utc_iso_ts = datetime.datetime.utcnow().isoformat() + "Z"
+ cursor.execute("""
+ UPDATE tasks
+ SET status = ?,
+ updated_at = ?,
+ generation_started_at = ?
+ WHERE id = ?
+ """, (STATUS_IN_PROGRESS, current_utc_iso_ts, current_utc_iso_ts, task_id))
+
+ return {"task_id": task_id, "params": json.loads(task_row[1]), "task_type": task_row[2], "project_id": task_row[3]}
+ return None
+
+ try:
+ return execute_sqlite_with_retry(db_path_str, _get_operation)
+ except Exception as e:
+ print(f"Error getting oldest queued task: {e}")
+ return None
+
+def update_task_status_sqlite(db_path_str: str, task_id: str, status: str, output_location_val: str | None = None):
+ """Updates a task's status and updated_at timestamp with proper error handling"""
+ def _update_operation(conn, task_id, status, output_location_val):
+ cursor = conn.cursor()
+
+ if status == STATUS_COMPLETE and output_location_val is not None:
+ # Step 1: Update output_location and updated_at for the location change
+ current_utc_iso_ts_loc_update = datetime.datetime.utcnow().isoformat() + "Z"
+ dprint(f"SQLite Update (Split Step 1): Updating output_location for {task_id} to {output_location_val}")
+ cursor.execute("UPDATE tasks SET output_location = ?, updated_at = ? WHERE id = ?",
+ (output_location_val, current_utc_iso_ts_loc_update, task_id))
+ conn.commit() # Explicitly commit the output_location update
+
+ # Step 2: Update status, updated_at, and generation_processed_at for the completion
+ current_utc_iso_ts_status_update = datetime.datetime.utcnow().isoformat() + "Z"
+ dprint(f"SQLite Update (Split Step 2): Updating status for {task_id} to {status}")
+ cursor.execute("UPDATE tasks SET status = ?, updated_at = ?, generation_processed_at = ? WHERE id = ?",
+ (status, current_utc_iso_ts_status_update, current_utc_iso_ts_status_update, task_id))
+ # The final commit for this status update will be handled by execute_sqlite_with_retry
+
+ elif status == STATUS_FAILED and output_location_val is not None: # output_location_val is error message here
+ current_utc_iso_ts_fail_update = datetime.datetime.utcnow().isoformat() + "Z"
+ dprint(f"SQLite Update (Single): Updating status to FAILED and output_location (error msg) for {task_id}")
+ cursor.execute("UPDATE tasks SET status = ?, updated_at = ?, output_location = ? WHERE id = ?",
+ (status, current_utc_iso_ts_fail_update, output_location_val, task_id))
+
+ else: # For "In Progress" or other statuses, or if output_location_val is None
+ current_utc_iso_ts_progress_update = datetime.datetime.utcnow().isoformat() + "Z"
+ dprint(f"SQLite Update (Single): Updating status for {task_id} to {status} (output_location_val: {output_location_val})")
+ # If output_location_val is None even for COMPLETE or FAILED, it won't be set here.
+ # This branch primarily handles IN_PROGRESS or status changes where output_location is not part of the update.
+ # If status is COMPLETE/FAILED and output_location_val is None, only status and updated_at change.
+ if output_location_val is not None and status in [STATUS_COMPLETE, STATUS_FAILED]:
+ # This case should ideally be caught by the specific branches above,
+ # but as a safeguard if logic changes:
+ dprint(f"SQLite Update (Single with output_location): Updating status, output_location for {task_id}")
+ cursor.execute("UPDATE tasks SET status = ?, updated_at = ?, output_location = ? WHERE id = ?",
+ (status, current_utc_iso_ts_progress_update, output_location_val, task_id))
+ else:
+ cursor.execute("UPDATE tasks SET status = ?, updated_at = ? WHERE id = ?",
+ (status, current_utc_iso_ts_progress_update, task_id))
+ return True
+
+ try:
+ execute_sqlite_with_retry(db_path_str, _update_operation, task_id, status, output_location_val)
+ dprint(f"SQLite: Updated status of task {task_id} to {status}. Output: {output_location_val if output_location_val else 'N/A'}")
+ except Exception as e:
+ print(f"Error updating task status for {task_id}: {e}")
+ # Don't raise here to avoid crashing the main loop
+
+def init_db():
+ """Initializes the database, dispatching to the correct implementation."""
+ if DB_TYPE == "supabase":
+ return init_db_supabase()
+ else:
+ return _init_db_sqlite(SQLITE_DB_PATH)
+
+def get_oldest_queued_task():
+ """Gets the oldest queued task, dispatching to the correct implementation."""
+ if DB_TYPE == "supabase":
+ return get_oldest_queued_task_supabase()
+ else:
+ return get_oldest_queued_task_sqlite(SQLITE_DB_PATH)
+
+def update_task_status(task_id: str, status: str, output_location: str | None = None):
+ """Updates a task's status, dispatching to the correct implementation."""
+ print(f"[UPDATE_TASK_STATUS_DEBUG] Called with:")
+ print(f"[UPDATE_TASK_STATUS_DEBUG] task_id: '{task_id}'")
+ print(f"[UPDATE_TASK_STATUS_DEBUG] status: '{status}'")
+ print(f"[UPDATE_TASK_STATUS_DEBUG] output_location: '{output_location}'")
+ print(f"[UPDATE_TASK_STATUS_DEBUG] DB_TYPE: '{DB_TYPE}'")
+
+ try:
+ if DB_TYPE == "supabase":
+ print(f"[UPDATE_TASK_STATUS_DEBUG] Dispatching to update_task_status_supabase")
+ result = update_task_status_supabase(task_id, status, output_location)
+ print(f"[UPDATE_TASK_STATUS_DEBUG] update_task_status_supabase completed successfully")
+ return result
+ else:
+ print(f"[UPDATE_TASK_STATUS_DEBUG] Dispatching to update_task_status_sqlite")
+ result = update_task_status_sqlite(SQLITE_DB_PATH, task_id, status, output_location)
+ print(f"[UPDATE_TASK_STATUS_DEBUG] update_task_status_sqlite completed successfully")
+ return result
+ except Exception as e:
+ print(f"[UPDATE_TASK_STATUS_DEBUG] ❌ Exception in update_task_status: {e}")
+ print(f"[UPDATE_TASK_STATUS_DEBUG] Exception type: {type(e).__name__}")
+ traceback.print_exc()
+ raise
+
+def init_db_supabase(): # Renamed from init_db_postgres
+ """Check if the Supabase tasks table exists (assuming it's already set up)."""
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot check database table.")
+ sys.exit(1)
+ try:
+ # Simply check if the tasks table exists by querying it
+ # Since the table already exists, we don't need to create it
+ result = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("count", count="exact").limit(1).execute()
+ print(f"Supabase: Table '{PG_TABLE_NAME}' exists and accessible (count: {result.count})")
+ return True
+ except Exception as e:
+ print(f"[ERROR] Supabase table check failed: {e}")
+ # Don't exit - the table might exist but have different permissions
+ # Let the actual operations try and fail gracefully
+ return False
+
+def get_oldest_queued_task_supabase(worker_id: str | None = None): # Renamed from get_oldest_queued_task_postgres
+ """Fetches the oldest task via Supabase Edge Function only."""
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot get task.")
+ return None
+
+ # Use provided worker_id or use the specific GPU worker ID
+ if not worker_id:
+ worker_id = "gpu-20250723_221138-afa8403b"
+ dprint(f"DEBUG: No worker_id provided, using default GPU worker: {worker_id}")
+ else:
+ dprint(f"DEBUG: Using provided worker_id: {worker_id}")
+
+ # Use Edge Function exclusively
+ edge_url = (
+ SUPABASE_EDGE_CLAIM_TASK_URL
+ or os.getenv('SUPABASE_EDGE_CLAIM_TASK_URL')
+ or (f"{SUPABASE_URL.rstrip('/')}/functions/v1/claim-next-task" if SUPABASE_URL else None)
+ )
+
+ if edge_url and SUPABASE_ACCESS_TOKEN:
+ try:
+ dprint(f"DEBUG get_oldest_queued_task_supabase: Calling Edge Function at {edge_url}")
+ dprint(f"DEBUG: Using worker_id: {worker_id}")
+
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': f'Bearer {SUPABASE_ACCESS_TOKEN}'
+ }
+
+ # Pass worker_id in the request body for edge function to use
+ payload = {"worker_id": worker_id}
+
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=15)
+ dprint(f"Edge Function response status: {resp.status_code}")
+
+ if resp.status_code == 200:
+ task_data = resp.json()
+ dprint(f"Edge Function claimed task: {task_data}")
+ return task_data # Already in the expected format
+ elif resp.status_code == 204:
+ dprint("Edge Function: No queued tasks available")
+ return None
+ else:
+ dprint(f"Edge Function returned {resp.status_code}: {resp.text}")
+ return None
+ except Exception as e_edge:
+ dprint(f"Edge Function call failed: {e_edge}")
+ return None
+ else:
+ dprint("ERROR: No edge function URL or access token available for task claiming")
+ return None
+
+def update_task_status_supabase(task_id_str, status_str, output_location_val=None): # Renamed from update_task_status_postgres
+ """Updates a task's status via Supabase Edge Functions only."""
+ dprint(f"[DEBUG] update_task_status_supabase called: task_id={task_id_str}, status={status_str}, output_location={output_location_val}")
+
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot update task status.")
+ return
+
+ # --- Use edge functions for ALL status updates ---
+ if status_str == STATUS_COMPLETE and output_location_val is not None:
+ # Use complete-task edge function for completion with file
+ edge_url = (
+ SUPABASE_EDGE_COMPLETE_TASK_URL
+ or (os.getenv("SUPABASE_EDGE_COMPLETE_TASK_URL") or None)
+ or (f"{SUPABASE_URL.rstrip('/')}/functions/v1/complete-task" if SUPABASE_URL else None)
+ )
+
+ if not edge_url:
+ print(f"[ERROR] No complete-task edge function URL available")
+ return
+
+ try:
+ # Check if output_location_val is a local file path
+ output_path = Path(output_location_val)
+
+ if output_path.exists() and output_path.is_file():
+ # Read the file and encode as base64
+ with open(output_path, 'rb') as f:
+ file_data = base64.b64encode(f.read()).decode('utf-8')
+
+ headers = {"Content-Type": "application/json"}
+ if SUPABASE_ACCESS_TOKEN:
+ headers["Authorization"] = f"Bearer {SUPABASE_ACCESS_TOKEN}"
+
+ payload = {
+ "task_id": task_id_str,
+ "file_data": file_data,
+ "filename": output_path.name
+ }
+ dprint(f"[DEBUG] Calling complete-task Edge Function with file upload for task {task_id_str}")
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=60)
+
+ if resp.status_code == 200:
+ dprint(f"[DEBUG] Edge function SUCCESS for task {task_id_str} → status COMPLETE with file upload")
+ return
+ else:
+ error_msg = f"complete-task edge function failed: {resp.status_code} - {resp.text}"
+ print(f"[ERROR] {error_msg}")
+ # Use update-task-status edge function to mark as failed
+ _mark_task_failed_via_edge_function(task_id_str, f"Upload failed: {error_msg}")
+ return
+ else:
+ # Not a local file, treat as URL
+ payload = {"task_id": task_id_str, "output_location": output_location_val}
+
+ headers = {"Content-Type": "application/json"}
+ if SUPABASE_ACCESS_TOKEN:
+ headers["Authorization"] = f"Bearer {SUPABASE_ACCESS_TOKEN}"
+
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=30)
+
+ if resp.status_code == 200:
+ dprint(f"[DEBUG] Edge function SUCCESS for task {task_id_str} → status COMPLETE")
+ return
+ else:
+ error_msg = f"complete-task edge function failed: {resp.status_code} - {resp.text}"
+ print(f"[ERROR] {error_msg}")
+ # Use update-task-status edge function to mark as failed
+ _mark_task_failed_via_edge_function(task_id_str, f"Completion failed: {error_msg}")
+ return
+ except Exception as e_edge:
+ print(f"[ERROR] complete-task edge function exception: {e_edge}")
+ return
+ else:
+ # Use update-task-status edge function for all other status updates
+ edge_url = (
+ os.getenv("SUPABASE_EDGE_UPDATE_TASK_URL")
+ or (f"{SUPABASE_URL.rstrip('/')}/functions/v1/update-task-status" if SUPABASE_URL else None)
+ )
+
+ if not edge_url:
+ print(f"[ERROR] No update-task-status edge function URL available")
+ return
+
+ try:
+ headers = {"Content-Type": "application/json"}
+ if SUPABASE_ACCESS_TOKEN:
+ headers["Authorization"] = f"Bearer {SUPABASE_ACCESS_TOKEN}"
+
+ payload = {
+ "task_id": task_id_str,
+ "status": status_str
+ }
+
+ if output_location_val:
+ payload["output_location"] = output_location_val
+
+ dprint(f"[DEBUG] Calling update-task-status Edge Function for task {task_id_str} → {status_str}")
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=30)
+
+ if resp.status_code == 200:
+ dprint(f"[DEBUG] Edge function SUCCESS for task {task_id_str} → status {status_str}")
+ return
+ else:
+ print(f"[ERROR] update-task-status edge function failed: {resp.status_code} - {resp.text}")
+ return
+
+ except Exception as e:
+ print(f"[ERROR] update-task-status edge function exception: {e}")
+ return
+
+def _migrate_supabase_schema():
+ """Legacy migration function - no longer used. Edge Function architecture complete."""
+ dprint("Supabase Migration: Migration to Edge Functions complete. Schema migrations handled externally.")
+ return # No-op - migrations complete
+
+def _run_db_migrations():
+ """Runs database migrations based on the configured DB_TYPE."""
+ dprint(f"DB Migrations: Running for DB_TYPE: {DB_TYPE}")
+ if DB_TYPE == "sqlite":
+ if SQLITE_DB_PATH:
+ _migrate_sqlite_schema(SQLITE_DB_PATH)
+ else:
+ print("[ERROR] DB Migration: SQLITE_DB_PATH not set. Skipping SQLite migration.")
+ elif DB_TYPE == "supabase":
+ # The Supabase schema is managed externally. Edge Function architecture complete.
+ dprint("DB Migrations: Skipping Supabase migrations (table assumed to exist).")
+ return
+ else:
+ dprint(f"DB Migrations: No migration logic for DB_TYPE '{DB_TYPE}'. Skipping migrations.")
+
+def add_task_to_db(task_payload: dict, task_type_str: str, dependant_on: str | None = None, db_path: str | None = None):
+ """
+ Adds a new task to the database, dispatching to SQLite or Supabase.
+ The `db_path` argument is for legacy SQLite compatibility and is ignored for Supabase.
+ """
+ task_id = task_payload.get("task_id")
+ if not task_id:
+ raise ValueError("task_id must be present in the task_payload.")
+
+ # Shared logic: Sanitize payload and get project_id
+ params_for_db = task_payload.copy()
+ params_for_db.pop("task_type", None) # Ensure task_type is not duplicated in params
+ params_json_str = json.dumps(params_for_db)
+ project_id = task_payload.get("project_id", "default_project_id")
+
+ if DB_TYPE == "supabase":
+ # Use Edge Function exclusively
+
+ # Build Edge URL – env var override > global constant > default pattern
+ edge_url = (
+ SUPABASE_EDGE_CREATE_TASK_URL # may be set at runtime
+ if "SUPABASE_EDGE_CREATE_TASK_URL" in globals() else None
+ ) or (os.getenv("SUPABASE_EDGE_CREATE_TASK_URL") or None) or (
+ f"{SUPABASE_URL.rstrip('/')}/functions/v1/create-task" if SUPABASE_URL else None
+ )
+
+ if not edge_url:
+ raise ValueError("Edge Function URL for create-task is not configured")
+
+ headers = {"Content-Type": "application/json"}
+ if SUPABASE_ACCESS_TOKEN:
+ headers["Authorization"] = f"Bearer {SUPABASE_ACCESS_TOKEN}"
+
+ payload_edge = {
+ "task_id": task_id,
+ "params": params_for_db, # pass JSON directly
+ "task_type": task_type_str,
+ "project_id": project_id,
+ "dependant_on": dependant_on,
+ }
+
+ dprint(f"Supabase Edge call >>> POST {edge_url} payload={str(payload_edge)[:120]}…")
+
+ try:
+ resp = httpx.post(edge_url, json=payload_edge, headers=headers, timeout=30)
+
+ if resp.status_code == 200:
+ print(f"Task {task_id} (Type: {task_type_str}) queued via Edge Function.")
+ return
+ else:
+ error_msg = f"Edge Function create-task failed: {resp.status_code} - {resp.text}"
+ print(f"[ERROR] {error_msg}")
+ raise RuntimeError(error_msg)
+
+ except httpx.RequestError as e:
+ error_msg = f"Edge Function create-task request failed: {e}"
+ print(f"[ERROR] {error_msg}")
+ raise RuntimeError(error_msg)
+
+ else: # Default to SQLite
+ db_to_use = db_path if db_path else SQLITE_DB_PATH
+ if not db_to_use:
+ raise ValueError("SQLite DB path is not configured.")
+
+ def _add_op(conn):
+ cursor = conn.cursor()
+ current_timestamp = datetime.datetime.utcnow().isoformat() + "Z"
+ cursor.execute(
+ f"INSERT INTO tasks (id, params, task_type, status, created_at, project_id, dependant_on) VALUES (?, ?, ?, ?, ?, ?, ?)",
+ (
+ task_id,
+ params_json_str,
+ task_type_str,
+ STATUS_QUEUED,
+ current_timestamp,
+ project_id,
+ dependant_on,
+ ),
+ )
+ try:
+ execute_sqlite_with_retry(db_to_use, _add_op)
+ print(f"Task {task_id} (Type: {task_type_str}) added to SQLite database {db_to_use}.")
+ except Exception as e:
+ print(f"SQLite error when adding task {task_id} (Type: {task_type_str}): {e}")
+ raise
+
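+# Hypothetical usage sketch for add_task_to_db (ids, prompt and resolution are examples only):
+#
+#     add_task_to_db(
+#         {
+#             "task_id": "task_abc123",
+#             "project_id": "default_project_id",
+#             "prompt": "a slow pan across a foggy forest",
+#             "resolution": "832x480",
+#         },
+#         task_type_str="standard_wgp_task",
+#     )
+#     # SQLite: inserts a 'Queued' row; Supabase: POSTs the payload to the create-task Edge Function.
+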
+def poll_task_status(task_id: str, poll_interval_seconds: int = 10, timeout_seconds: int = 1800, db_path: str | None = None) -> str | None:
+ """
+ Polls the DB for task completion and returns the output_location.
+ Dispatches to SQLite or Supabase. `db_path` is for legacy SQLite calls.
+ """
+ print(f"Polling for completion of task {task_id} (timeout: {timeout_seconds}s)...")
+ start_time = time.time()
+ last_status_print_time = 0
+
+ while True:
+ current_time = time.time()
+ if current_time - start_time > timeout_seconds:
+ print(f"Error: Timeout polling for task {task_id} after {timeout_seconds} seconds.")
+ return None
+
+ status = None
+ output_location = None
+
+ if DB_TYPE == "supabase":
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot poll status.")
+ time.sleep(poll_interval_seconds)
+ continue
+ try:
+ # Direct table query for polling status
+ resp = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("status, output_location").eq("id", task_id).single().execute()
+ if resp.data:
+ status = resp.data.get("status")
+ output_location = resp.data.get("output_location")
+ except Exception as e:
+ print(f"Supabase error while polling task {task_id}: {e}. Retrying...")
+ else: # SQLite
+ db_to_use = db_path if db_path else SQLITE_DB_PATH
+ if not db_to_use:
+ raise ValueError("SQLite DB path is not configured for polling.")
+
+ def _poll_op(conn):
+ conn.row_factory = sqlite3.Row
+ cursor = conn.cursor()
+ cursor.execute(f"SELECT status, output_location FROM tasks WHERE id = ?", (task_id,))
+ return cursor.fetchone()
+ try:
+ row = execute_sqlite_with_retry(db_to_use, _poll_op)
+ if row:
+ status = row["status"]
+ output_location = row["output_location"]
+ except Exception as e:
+ print(f"SQLite error while polling task {task_id}: {e}. Retrying...")
+
+ if status:
+ if current_time - last_status_print_time > poll_interval_seconds * 2 :
+ print(f"Task {task_id}: Status = {status} (Output: {output_location if output_location else 'N/A'})")
+ last_status_print_time = current_time
+
+ if status == STATUS_COMPLETE:
+ if output_location:
+ print(f"Task {task_id} completed successfully. Output: {output_location}")
+ return output_location
+ else:
+ print(f"Error: Task {task_id} is COMPLETE but output_location is missing. Assuming failure.")
+ return None
+ elif status == STATUS_FAILED:
+ print(f"Error: Task {task_id} failed. Error details: {output_location}")
+ return None
+ elif status not in [STATUS_QUEUED, STATUS_IN_PROGRESS]:
+ print(f"Warning: Task {task_id} has unknown status '{status}'. Treating as error.")
+ return None
+ else:
+ if current_time - last_status_print_time > poll_interval_seconds * 2 :
+ print(f"Task {task_id}: Not found in DB yet or status pending...")
+ last_status_print_time = current_time
+
+ time.sleep(poll_interval_seconds)
+
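+# Hypothetical usage sketch: block until a previously queued task finishes (values are examples only):
+#
+#     output = poll_task_status("task_abc123", poll_interval_seconds=10, timeout_seconds=1800)
+#     if output is None:
+#         print("Task failed or timed out")
+#     else:
+#         print(f"Task finished, output at: {output}")
+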
+# Helper to query DB for a specific task's output (needed by segment handler)
+def get_task_output_location_from_db(task_id_to_find: str) -> str | None:
+ dprint(f"Querying DB for output location of task: {task_id_to_find}")
+ if DB_TYPE == "sqlite":
+ def _get_op(conn):
+ cursor = conn.cursor()
+ # Ensure we only get tasks that are actually complete with an output
+ cursor.execute("SELECT output_location FROM tasks WHERE id = ? AND status = ? AND output_location IS NOT NULL",
+ (task_id_to_find, STATUS_COMPLETE))
+ row = cursor.fetchone()
+ return row[0] if row else None
+ try:
+ return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
+ except Exception as e:
+ print(f"Error querying SQLite for task output {task_id_to_find}: {e}")
+ return None
+ elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
+ try:
+ response = SUPABASE_CLIENT.table(PG_TABLE_NAME)\
+ .select("output_location, status")\
+ .eq("id", task_id_to_find)\
+ .single()\
+ .execute()
+
+ if response.data:
+ task_details = response.data
+ if task_details.get("status") == STATUS_COMPLETE and task_details.get("output_location"):
+ return task_details.get("output_location")
+ else:
+ dprint(f"Task {task_id_to_find} found but not complete or no output_location. Status: {task_details.get('status')}")
+ return None
+ else:
+ dprint(f"Task {task_id_to_find} not found in Supabase.")
+ return None
+ except Exception as e:
+ print(f"Error querying Supabase for task output {task_id_to_find}: {e}")
+ traceback.print_exc()
+ return None
+ dprint(f"DB type {DB_TYPE} not supported or client not init for get_task_output_location_from_db")
+ return None
+
+def get_task_params(task_id: str) -> str | None:
+ """Gets the raw params JSON string for a given task ID."""
+ if DB_TYPE == "sqlite":
+ def _get_op(conn):
+ cursor = conn.cursor()
+ cursor.execute("SELECT params FROM tasks WHERE id = ?", (task_id,))
+ row = cursor.fetchone()
+ return row[0] if row else None
+ return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
+ elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
+ try:
+ resp = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("params").eq("id", task_id).single().execute()
+ if resp.data:
+ return resp.data.get("params")
+ return None
+ except Exception as e:
+ dprint(f"Error getting task params for {task_id} from Supabase: {e}")
+ return None
+ return None
+
+def get_task_dependency(task_id: str) -> str | None:
+ """Gets the dependency task ID for a given task ID."""
+ if DB_TYPE == "sqlite":
+ def _get_op(conn):
+ cursor = conn.cursor()
+ cursor.execute("SELECT dependant_on FROM tasks WHERE id = ?", (task_id,))
+ row = cursor.fetchone()
+ return row[0] if row else None
+ return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
+ elif DB_TYPE == "supabase" and SUPABASE_CLIENT:
+ try:
+ response = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("dependant_on").eq("id", task_id).single().execute()
+ if response.data:
+ return response.data.get("dependant_on")
+ return None
+ except Exception as e_supabase_dep:
+ dprint(f"Error fetching dependant_on from Supabase for task {task_id}: {e_supabase_dep}")
+ return None
+ return None
+
+def get_predecessor_output_via_edge_function(task_id: str) -> tuple[str | None, str | None]:
+ """
+ Gets both the predecessor task ID and its output location in a single call using Edge Function.
+ Returns: (predecessor_id, output_location) or (None, None) if no dependency or error.
+
+ This replaces the separate calls to get_task_dependency() + get_task_output_location_from_db().
+ """
+ if DB_TYPE == "sqlite":
+ # For SQLite, fall back to separate calls since we don't have Edge Functions
+ predecessor_id = get_task_dependency(task_id)
+ if predecessor_id:
+ output_location = get_task_output_location_from_db(predecessor_id)
+ return predecessor_id, output_location
+ return None, None
+
+ elif DB_TYPE == "supabase" and SUPABASE_URL and SUPABASE_ACCESS_TOKEN:
+ # Use the new Edge Function for Supabase
+ edge_url = f"{SUPABASE_URL.rstrip('/')}/functions/v1/get-predecessor-output"
+
+ try:
+ dprint(f"Calling Edge Function: {edge_url} for task {task_id}")
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': f'Bearer {SUPABASE_ACCESS_TOKEN}'
+ }
+
+ resp = httpx.post(edge_url, json={"task_id": task_id}, headers=headers, timeout=15)
+ dprint(f"Edge Function response status: {resp.status_code}")
+
+ if resp.status_code == 200:
+ result = resp.json()
+ dprint(f"Edge Function result: {result}")
+
+ if result is None:
+ # No dependency
+ return None, None
+
+ predecessor_id = result.get("predecessor_id")
+ output_location = result.get("output_location")
+ return predecessor_id, output_location
+
+ elif resp.status_code == 404:
+ dprint(f"Edge Function: Task {task_id} not found")
+ return None, None
+ else:
+ dprint(f"Edge Function returned {resp.status_code}: {resp.text}. Falling back to direct queries.")
+ # Fall back to separate calls
+ predecessor_id = get_task_dependency(task_id)
+ if predecessor_id:
+ output_location = get_task_output_location_from_db(predecessor_id)
+ return predecessor_id, output_location
+ return None, None
+
+ except Exception as e_edge:
+ dprint(f"Edge Function call failed: {e_edge}. Falling back to direct queries.")
+ # Fall back to separate calls
+ predecessor_id = get_task_dependency(task_id)
+ if predecessor_id:
+ output_location = get_task_output_location_from_db(predecessor_id)
+ return predecessor_id, output_location
+ return None, None
+
+ # If we can't use Edge Function, fall back to separate calls
+ predecessor_id = get_task_dependency(task_id)
+ if predecessor_id:
+ output_location = get_task_output_location_from_db(predecessor_id)
+ return predecessor_id, output_location
+ return None, None
+
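+# Hypothetical usage sketch: resolve a task's predecessor and its output in one call
+# (task id is an example only):
+#
+#     predecessor_id, prev_output = get_predecessor_output_via_edge_function("task_abc124")
+#     if predecessor_id and prev_output:
+#         ...  # e.g. reuse the previous segment's video as guidance for the next one
+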
+
+def get_completed_segment_outputs_for_stitch(run_id: str, project_id: str | None = None) -> list:
+ """Gets completed travel_segment outputs for a given run_id for stitching."""
+
+ if DB_TYPE == "sqlite":
+ def _get_op(conn):
+ cursor = conn.cursor()
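+            # For each segment_index, keep only the most recently created completed
+            # 'travel_segment' row for this run (handles segments that were re-run).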
+ sql_query = f"""
+ SELECT json_extract(t.params, '$.segment_index') AS segment_idx, t.output_location
+ FROM tasks t
+ INNER JOIN (
+ SELECT
+ json_extract(params, '$.segment_index') AS segment_idx,
+ MAX(created_at) AS max_created_at
+ FROM tasks
+ WHERE json_extract(params, '$.orchestrator_run_id') = ?
+ AND task_type = 'travel_segment'
+ AND status = ?
+ AND output_location IS NOT NULL
+ GROUP BY segment_idx
+ ) AS latest_tasks
+ ON json_extract(t.params, '$.segment_index') = latest_tasks.segment_idx
+ AND t.created_at = latest_tasks.max_created_at
+ WHERE json_extract(t.params, '$.orchestrator_run_id') = ?
+ AND t.task_type = 'travel_segment'
+ AND t.status = ?
+ ORDER BY CAST(json_extract(t.params, '$.segment_index') AS INTEGER) ASC
+ """
+ cursor.execute(sql_query, (run_id, STATUS_COMPLETE, run_id, STATUS_COMPLETE))
+ rows = cursor.fetchall()
+ return rows
+ return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_op)
+ elif DB_TYPE == "supabase":
+        try:
+            edge_url = f"{SUPABASE_URL.rstrip('/')}/functions/v1/get-completed-segments"
+            dprint(f"Calling Edge Function: {edge_url} for run_id {run_id}, project_id {project_id}")
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': f'Bearer {SUPABASE_ACCESS_TOKEN}'
+ }
+ payload = {"run_id": run_id}
+ if project_id:
+ payload["project_id"] = project_id
+
+ resp = httpx.post(edge_url, json=payload, headers=headers, timeout=15)
+ if resp.status_code == 200:
+ results = resp.json()
+ sorted_results = sorted(results, key=lambda x: x['segment_index'])
+ return [(r['segment_index'], r['output_location']) for r in sorted_results]
+ else:
+ dprint(f"Edge Function returned {resp.status_code}: {resp.text}. Falling back to direct query.")
+ except Exception as e:
+ dprint(f"Edge Function failed: {e}. Falling back to direct query.")
+
+ # Fallback to direct query
+ try:
+ # First, let's debug by getting ALL completed tasks to see what's there
+ debug_resp = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("id, task_type, status, params, output_location")\
+ .eq("status", STATUS_COMPLETE).execute()
+
+ dprint(f"[DEBUG_STITCH] Looking for run_id: '{run_id}' (type: {type(run_id)})")
+ dprint(f"[DEBUG_STITCH] Total completed tasks in DB: {len(debug_resp.data) if debug_resp.data else 0}")
+
+ travel_segment_count = 0
+ matching_run_id_count = 0
+
+ if debug_resp.data:
+ for task in debug_resp.data:
+ task_type = task.get("task_type", "")
+ if task_type == "travel_segment":
+ travel_segment_count += 1
+ params_raw = task.get("params", {})
+ try:
+ params_obj = params_raw if isinstance(params_raw, dict) else json.loads(params_raw)
+ task_run_id = params_obj.get("orchestrator_run_id")
+ dprint(f"[DEBUG_STITCH] Found travel_segment task {task.get('id')}: orchestrator_run_id='{task_run_id}' (type: {type(task_run_id)}), segment_index={params_obj.get('segment_index')}, output_location={task.get('output_location', 'None')}")
+
+ if str(task_run_id) == str(run_id):
+ matching_run_id_count += 1
+ dprint(f"[DEBUG_STITCH] ✅ MATCH FOUND! Task {task.get('id')} matches run_id {run_id}")
+ except Exception as e_debug:
+ dprint(f"[DEBUG_STITCH] Error parsing params for task {task.get('id')}: {e_debug}")
+
+ dprint(f"[DEBUG_STITCH] Travel_segment tasks found: {travel_segment_count}")
+ dprint(f"[DEBUG_STITCH] Tasks matching run_id '{run_id}': {matching_run_id_count}")
+
+ # Now do the actual query
+ sel_resp = SUPABASE_CLIENT.table(PG_TABLE_NAME).select("params, output_location")\
+ .eq("task_type", "travel_segment").eq("status", STATUS_COMPLETE).execute()
+
+ results = []
+ if sel_resp.data:
+ import json
+ for i, row in enumerate(sel_resp.data):
+ params_raw = row.get("params")
+ if params_raw is None:
+ continue
+ try:
+ params_obj = params_raw if isinstance(params_raw, dict) else json.loads(params_raw)
+ except Exception as e:
+ continue
+
+ row_run_id = params_obj.get("orchestrator_run_id")
+
+ # Use string comparison to handle type mismatches
+ if str(row_run_id) == str(run_id):
+ seg_idx = params_obj.get("segment_index")
+ output_loc = row.get("output_location")
+ results.append((seg_idx, output_loc))
+ dprint(f"[DEBUG_STITCH] Added to results: segment_index={seg_idx}, output_location={output_loc}")
+
+ sorted_results = sorted(results, key=lambda x: x[0] if x[0] is not None else 0)
+ dprint(f"[DEBUG_STITCH] Final sorted results: {sorted_results}")
+ return sorted_results
+ except Exception as e_sel:
+ dprint(f"Stitch Supabase: Direct select failed: {e_sel}")
+ traceback.print_exc()
+ return []
+
+ return []
+
+def get_initial_task_counts() -> tuple[int, int] | None:
+ """Gets the total and queued task counts from the SQLite DB. Returns (None, None) on failure."""
+ if DB_TYPE != "sqlite": return None
+
+ def _get_counts_op(conn):
+ cursor = conn.cursor()
+ cursor.execute(f"SELECT COUNT(*) FROM {PG_TABLE_NAME}")
+ total_tasks = cursor.fetchone()[0]
+ cursor.execute(f"SELECT COUNT(*) FROM {PG_TABLE_NAME} WHERE status = ?", (STATUS_QUEUED,))
+ queued_tasks = cursor.fetchone()[0]
+ return total_tasks, queued_tasks
+
+ try:
+ return execute_sqlite_with_retry(SQLITE_DB_PATH, _get_counts_op)
+ except Exception as e:
+ print(f"SQLite error getting task counts: {e}")
+ return None
+
+def get_abs_path_from_db_path(db_path: str, dprint) -> Path | None:
+ """Helper to resolve a path from the DB (which might be relative) to a usable absolute path."""
+ if not db_path:
+ return None
+
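+    # Illustrative resolution (SQLite layout): a DB value such as "files/seg_000.mp4"
+    # resolves to <dir containing tasks.db>/public/files/seg_000.mp4, while absolute
+    # paths (typical for Supabase) are used as-is when they exist on disk.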
+ resolved_path = None
+ if DB_TYPE == "sqlite" and SQLITE_DB_PATH and isinstance(db_path, str) and db_path.startswith("files/"):
+ sqlite_db_parent = Path(SQLITE_DB_PATH).resolve().parent
+ resolved_path = (sqlite_db_parent / "public" / db_path).resolve()
+ dprint(f"Resolved SQLite relative path '{db_path}' to '{resolved_path}'")
+ else:
+ # Path from DB is already absolute (Supabase) or a non-standard path
+ resolved_path = Path(db_path).resolve()
+
+ if resolved_path and resolved_path.exists():
+ return resolved_path
+ else:
+ dprint(f"Warning: Resolved path '{resolved_path}' from DB path '{db_path}' does not exist.")
+ return None
+
+def upload_to_supabase_storage(file_path, bucket_name="image_uploads", custom_path=None):
+ """Upload a file to Supabase storage with improved error handling."""
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot upload file.")
+ return None
+
+ try:
+ # Ensure file exists
+ if not os.path.exists(file_path):
+ print(f"[ERROR] File does not exist: {file_path}")
+ return None
+
+ file_name = os.path.basename(file_path)
+ storage_path = custom_path or f"system/{file_name}"
+
+ dprint(f"Uploading {file_path} to {bucket_name}/{storage_path}")
+
+ # Read file data
+ with open(file_path, 'rb') as f:
+ file_data = f.read()
+
+ # Upload to Supabase storage
+ res = SUPABASE_CLIENT.storage.from_(bucket_name).upload(storage_path, file_data)
+
+ # Handle the new UploadResponse API
+ if hasattr(res, 'error') and res.error:
+ print(f"Supabase upload error: {res.error}")
+ return None
+ elif hasattr(res, 'data') and res.data:
+ # Success case - get the public URL
+ public_url = SUPABASE_CLIENT.storage.from_(bucket_name).get_public_url(storage_path)
+ dprint(f"Supabase upload successful. Public URL: {public_url}")
+ return public_url
+ else:
+ # Try to get public URL anyway (upload might have succeeded)
+ try:
+ public_url = SUPABASE_CLIENT.storage.from_(bucket_name).get_public_url(storage_path)
+ dprint(f"Upload completed, public URL: {public_url}")
+ return public_url
+ except Exception as url_error:
+ print(f"[ERROR] Could not get public URL after upload: {url_error}")
+ return None
+
+ except Exception as e:
+ print(f"[ERROR] An exception occurred during Supabase upload: {e}")
+ return None
+
+def mark_task_failed_supabase(task_id_str, error_message):
+ """Marks a task as Failed with an error message using direct database update."""
+ dprint(f"Marking task {task_id_str} as Failed with message: {error_message}")
+ if not SUPABASE_CLIENT:
+ print("[ERROR] Supabase client not initialized. Cannot mark task failed.")
+ return
+
+ # Use the standard update function which now uses direct database updates for non-COMPLETE statuses
+ update_task_status_supabase(task_id_str, STATUS_FAILED, error_message)
\ No newline at end of file
diff --git a/sm_functions/__init__.py b/source/sm_functions/__init__.py
similarity index 65%
rename from sm_functions/__init__.py
rename to source/sm_functions/__init__.py
index 5aa533028..39f8bca2c 100644
--- a/sm_functions/__init__.py
+++ b/source/sm_functions/__init__.py
@@ -5,13 +5,11 @@
"""
# Task Handlers
-from .travel_between_images import run_travel_between_images_task
-from .different_pose import run_different_pose_task
# Common Utilities
# These are re-exported here for convenience, allowing task modules to import
# them directly from `sm_functions.common_utils` or just `sm_functions` if preferred.
-from .common_utils import (
+from ..common_utils import (
DEBUG_MODE, # Note: This is a global, set by steerable_motion.py
DEFAULT_DB_TABLE_NAME,
STATUS_QUEUED,
@@ -44,24 +42,26 @@
get_resized_frame,
draw_multiline_text,
generate_debug_summary_video,
- generate_different_pose_debug_video_summary,
+ generate_different_perspective_debug_video_summary,
extract_specific_frame_ffmpeg,
concatenate_videos_ffmpeg,
get_video_frame_count_and_fps,
- get_image_dimensions_pil
+ get_image_dimensions_pil,
+ create_mask_video_from_inactive_indices,
+ create_simple_first_frame_mask_video
)
# --- Make video_utils directly importable ---
-from .video_utils import (
- ease as crossfade_ease, # Renamed to avoid conflict if other eases are added
- _blend_linear,
- _blend_linear_sharp,
- cross_fade_overlap_frames,
- extract_frames_from_video,
- create_video_from_frames_list,
- _apply_saturation_to_video_ffmpeg,
- color_match_video_to_reference
-)
+# from ..video_utils import (
+# crossfade_ease, # Renamed to avoid conflict if other eases are added
+# _blend_linear,
+# _blend_linear_sharp,
+# cross_fade_overlap_frames,
+# extract_frames_from_video,
+# create_video_from_frames_list,
+# _apply_saturation_to_video_ffmpeg,
+# color_match_video_to_reference
+# )
__all__ = [
# common_utils exports
@@ -76,18 +76,18 @@
"DEFAULT_DB_TABLE_NAME",
"get_image_dimensions_pil",
"draw_multiline_text",
- "generate_different_pose_debug_video_summary",
+ "generate_different_perspective_debug_video_summary",
+ "create_mask_video_from_inactive_indices",
+ "create_simple_first_frame_mask_video",
# travel_between_images exports
- "run_travel_between_images_task",
- # different_pose exports
- "run_different_pose_task",
+ # different_perspective exports
# video_utils exports
- "crossfade_ease",
- "_blend_linear",
- "_blend_linear_sharp",
- "cross_fade_overlap_frames",
- "extract_frames_from_video",
- "create_video_from_frames_list",
- "_apply_saturation_to_video_ffmpeg",
- "color_match_video_to_reference"
+ # "crossfade_ease",
+ # "_blend_linear",
+ # "_blend_linear_sharp",
+ # "cross_fade_overlap_frames",
+ # "extract_frames_from_video",
+ # "create_video_from_frames_list",
+ # "_apply_saturation_to_video_ffmpeg",
+ # "color_match_video_to_reference"
]
\ No newline at end of file
diff --git a/source/sm_functions/different_perspective.py b/source/sm_functions/different_perspective.py
new file mode 100644
index 000000000..3823cb3fc
--- /dev/null
+++ b/source/sm_functions/different_perspective.py
@@ -0,0 +1,338 @@
+"""Different-perspective task handler."""
+
+import json
+import shutil
+import traceback
+from pathlib import Path
+
+import cv2 # pip install opencv-python
+from PIL import Image # pip install Pillow
+
+# Import from the new common_utils and db_operations modules
+from .. import db_operations as db_ops
+from ..common_utils import (
+ DEBUG_MODE, dprint, generate_unique_task_id, add_task_to_db, poll_task_status,
+ save_frame_from_video, create_pose_interpolated_guide_video,
+ generate_different_perspective_debug_video_summary,
+ parse_resolution as sm_parse_resolution,
+ create_simple_first_frame_mask_video,
+ prepare_output_path_with_upload,
+ upload_and_get_final_output_location
+)
+from ..video_utils import rife_interpolate_images_to_video # For depth guide interpolation
+
+def _handle_different_perspective_orchestrator_task(task_params_from_db: dict, main_output_dir_base: Path, orchestrator_task_id_str: str, dprint):
+ """
+ This is the entry point for a 'different_perspective' job.
+ It sets up the environment and enqueues all necessary child tasks with
+ database-level dependencies.
+ """
+ print(f"--- Orchestrating Task: Different Perspective (ID: {orchestrator_task_id_str}) ---")
+ dprint(f"Orchestrator Task Params: {task_params_from_db}")
+
+ try:
+ public_files_dir = Path.cwd() / "public" / "files"
+ public_files_dir.mkdir(parents=True, exist_ok=True)
+ main_output_dir_base = public_files_dir
+
+ run_id = generate_unique_task_id("dp_run_")
+ work_dir = main_output_dir_base / f"different_perspective_run_{run_id}"
+ work_dir.mkdir(parents=True, exist_ok=True)
+ print(f"Working directory for this 'different_perspective' run: {work_dir.resolve()}")
+
+ task_id_user_pose = generate_unique_task_id("dp_user_pose_")
+ task_id_t2i = generate_unique_task_id("dp_t2i_")
+ task_id_extract = generate_unique_task_id("dp_extract_")
+ task_id_t2i_pose = generate_unique_task_id("dp_t2i_pose_")
+ task_id_final_gen = generate_unique_task_id("dp_final_gen_")
+
+ perspective_type = task_params_from_db.get("perspective_type", "pose").lower()
+
+ orchestrator_payload = {
+ "run_id": run_id,
+ "orchestrator_task_id": orchestrator_task_id_str,
+ "work_dir": str(work_dir.resolve()),
+ "main_output_dir": str(main_output_dir_base.resolve()),
+ "original_params": task_params_from_db,
+ "perspective_type": perspective_type,
+ "task_ids": {
+ "user_persp": task_id_user_pose,
+ "t2i": task_id_t2i,
+ "extract": task_id_extract,
+ "t2i_persp": task_id_t2i_pose,
+ "final_gen": task_id_final_gen,
+ },
+ "debug_mode": task_params_from_db.get("debug_mode", False),
+ "skip_cleanup": task_params_from_db.get("skip_cleanup", False),
+ }
+
+ previous_task_id = None
+
+ payload_user_persp = {
+ "task_id": task_id_user_pose,
+ "input_image_path": task_params_from_db['input_image_path'],
+ "dp_orchestrator_payload": orchestrator_payload,
+ "output_dir": str(work_dir.resolve()),
+ }
+ user_gen_task_type = "generate_openpose" if perspective_type == "pose" else "generate_depth"
+ db_ops.add_task_to_db(payload_user_persp, user_gen_task_type, dependant_on=previous_task_id)
+ dprint(f"Orchestrator {orchestrator_task_id_str} enqueued {user_gen_task_type} ({task_id_user_pose})")
+ previous_task_id = task_id_user_pose
+
+ payload_t2i = {
+ "task_id": task_id_t2i,
+ "prompt": task_params_from_db.get("prompt"),
+ "model": task_params_from_db.get("model_name"),
+ "resolution": task_params_from_db.get("resolution"),
+ "frames": 1,
+ "seed": task_params_from_db.get("seed", -1),
+ "use_causvid_lora": task_params_from_db.get("use_causvid_lora", False),
+ "dp_orchestrator_payload": orchestrator_payload,
+ "output_dir": str(work_dir.resolve()),
+ }
+ db_ops.add_task_to_db(payload_t2i, "wgp", dependant_on=previous_task_id)
+ dprint(f"Orchestrator {orchestrator_task_id_str} enqueued t2i_gen ({task_id_t2i})")
+ previous_task_id = task_id_t2i
+
+ payload_extract = {
+ "task_id": task_id_extract,
+ "input_video_task_id": task_id_t2i,
+ "frame_index": 0,
+ "dp_orchestrator_payload": orchestrator_payload,
+ "output_dir": str(work_dir.resolve()),
+ }
+ db_ops.add_task_to_db(payload_extract, "extract_frame", dependant_on=previous_task_id)
+ dprint(f"Orchestrator {orchestrator_task_id_str} enqueued extract_frame ({task_id_extract})")
+ previous_task_id = task_id_extract
+
+ payload_t2i_persp = {
+ "task_id": task_id_t2i_pose,
+ "input_image_task_id": task_id_extract,
+ "dp_orchestrator_payload": orchestrator_payload,
+ "output_dir": str(work_dir.resolve()),
+ }
+ db_ops.add_task_to_db(payload_t2i_persp, user_gen_task_type, dependant_on=previous_task_id)
+ dprint(f"Orchestrator {orchestrator_task_id_str} enqueued {user_gen_task_type} ({task_id_t2i_pose})")
+ previous_task_id = task_id_t2i_pose
+
+ payload_final_gen = {
+ "task_id": task_id_final_gen,
+ "dp_orchestrator_payload": orchestrator_payload,
+ }
+ db_ops.add_task_to_db(payload_final_gen, "dp_final_gen", dependant_on=previous_task_id)
+ dprint(f"Orchestrator {orchestrator_task_id_str} enqueued dp_final_gen ({task_id_final_gen})")
+
+ return True, f"Successfully enqueued different_perspective job graph with run_id {run_id}."
+ except Exception as e:
+ error_msg = f"Different Perspective orchestration failed: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ return False, error_msg
+
+
+# -----------------------------------------------------------------------------
+# Final generation step – run WGP inline (no DB queue / poll)
+# -----------------------------------------------------------------------------
+# We refactor this routine so that it behaves like the travel pipeline: the
+# heavy WGP call is executed synchronously via ``process_single_task`` instead
+# of being queued as a separate DB task and then polled for completion. This
+# avoids the single-worker deadlock where the same process would wait on a task it
+# is supposed to execute itself.
+
+# NOTE: The headless server will pass ``wgp_mod`` (imported Wan2GP module), the
+# current ``main_output_dir_base`` and a reference to its own
+# ``process_single_task`` so that we can leverage the existing wrapper.
+
+def _handle_dp_final_gen_task(
+ *,
+ wgp_mod,
+ main_output_dir_base: Path,
+ process_single_task, # recursive call helper supplied by headless.py
+ task_params_from_db: dict,
+ dprint,
+):
+ """
+ Handles the final step of the 'different_perspective' process. It gathers all
+ the required artifacts generated by previous tasks in the dependency graph,
+ creates the final guide video, runs the last generation, extracts the
+ final image, and performs cleanup.
+ """
+ payload = task_params_from_db
+ orchestrator_payload = payload.get("dp_orchestrator_payload")
+ if not orchestrator_payload:
+ return False, "Final Gen failed: 'dp_orchestrator_payload' not found.", None
+
+ task_ids = orchestrator_payload.get("task_ids", {})
+ work_dir = Path(orchestrator_payload["work_dir"])
+ original_params = orchestrator_payload["original_params"]
+ final_path_for_db = None
+
+ dprint("DP Final Gen: Starting final generation step.")
+
+ try:
+ perspective_type = orchestrator_payload.get("perspective_type", "pose").lower()
+
+        user_persp_id = task_ids["user_persp"]
+        t2i_persp_id = task_ids["t2i_persp"]
+        extract_id = task_ids["extract"]
+
+        user_persp_path_db = db_ops.get_task_output_location_from_db(user_persp_id)
+        t2i_persp_path_db = db_ops.get_task_output_location_from_db(t2i_persp_id)
+        extract_path_db = db_ops.get_task_output_location_from_db(extract_id)
+
+        user_persp_image_path = db_ops.get_abs_path_from_db_path(user_persp_path_db, dprint)
+        t2i_persp_image_path = db_ops.get_abs_path_from_db_path(t2i_persp_path_db, dprint)
+        # Frame extracted from the t2i generation (output of the 'extract_frame' child task).
+        t2i_image_path = db_ops.get_abs_path_from_db_path(extract_path_db, dprint)
+
+        if not all([user_persp_image_path, t2i_persp_image_path, t2i_image_path]):
+            return False, "Could not resolve one or more required image paths from previous tasks.", None
+
+ print("\nDP Final Gen: Creating custom guide video…")
+ custom_guide_video_path = work_dir / f"{generate_unique_task_id('dp_custom_guide_')}.mp4"
+
+ if perspective_type == "pose":
+ create_pose_interpolated_guide_video(
+ output_video_path=custom_guide_video_path,
+ resolution=sm_parse_resolution(original_params.get("resolution")),
+ total_frames=original_params.get("output_video_frames", 16),
+ start_image_path=Path(original_params['input_image_path']),
+ end_image_path=t2i_image_path,
+ fps=original_params.get("fps_helpers", 16),
+ )
+ print(f"DP Final Gen: Pose guide video created: {custom_guide_video_path}")
+ else: # depth
+ # Use RIFE to interpolate between depth maps
+ try:
+ img_start = Image.open(user_persp_image_path).convert("RGB")
+ img_end = Image.open(t2i_persp_image_path).convert("RGB")
+
+ rife_success = rife_interpolate_images_to_video(
+ image1=img_start,
+ image2=img_end,
+ num_frames=original_params.get("output_video_frames", 16),
+ resolution_wh=sm_parse_resolution(original_params.get("resolution")),
+ output_path=custom_guide_video_path,
+ fps=original_params.get("fps_helpers", 16),
+ dprint_func=dprint,
+ )
+ if not rife_success:
+ return False, "Failed to create depth interpolated guide video via RIFE", None
+ print(f"DP Final Gen: Depth guide video created: {custom_guide_video_path}")
+ except Exception as e_gv:
+ print(f"[ERROR] DP Final Gen: Depth guide video creation failed: {e_gv}")
+ traceback.print_exc()
+ return False, f"Guide video creation failed: {e_gv}", None
+
+ # ------------------------------------------------------------------
+ # 2. Prepare WGP inline generation payload
+ # ------------------------------------------------------------------
+
+ final_video_task_id = generate_unique_task_id("dp_final_video_")
+
+ final_video_payload = {
+ "task_id": final_video_task_id,
+ "prompt": original_params.get("prompt"),
+ "model": original_params.get("model_name"),
+ "resolution": original_params.get("resolution"),
+ "frames": original_params.get("output_video_frames", 16),
+ "seed": original_params.get("seed", -1) + 1,
+ "video_guide_path": str(custom_guide_video_path.resolve()),
+ "image_refs_paths": [original_params['input_image_path']],
+ "image_prompt_type": "IV",
+ "use_causvid_lora": original_params.get("use_causvid_lora", False),
+ # ensure outputs stay inside the run work dir
+ "output_dir": str(work_dir.resolve()),
+ }
+
+ # Optionally create a mask video to freeze first frame
+ if original_params.get("use_mask_for_first_frame", True):
+ dprint("DP Final Gen: Creating mask video to preserve first frame...")
+ mask_video_path = work_dir / f"{generate_unique_task_id('dp_mask_')}.mp4"
+
+ created_mask = create_simple_first_frame_mask_video(
+ total_frames=original_params.get("output_video_frames", 16),
+ resolution_wh=sm_parse_resolution(original_params.get("resolution")),
+ output_path=mask_video_path,
+ fps=original_params.get("fps_helpers", 16),
+ task_id_for_logging=final_video_task_id,
+ dprint=dprint
+ )
+
+ if created_mask:
+ final_video_payload["video_mask_path"] = str(created_mask.resolve())
+ final_video_payload["video_prompt_type"] = "IVM" # Image + Video guide + Mask
+ dprint(f"DP Final Gen: Mask video created at {created_mask}")
+ else:
+ dprint("DP Final Gen: Warning - Failed to create mask video, proceeding without mask")
+
+ # ------------------------------------------------------------------
+ # 3. Execute WGP synchronously (no DB queue / polling)
+ # ------------------------------------------------------------------
+
+ print("\nDP Final Gen: Launching inline WGP generation for final video…")
+
+ generation_success, final_video_output_db = process_single_task(
+ wgp_mod,
+ final_video_payload,
+ main_output_dir_base,
+ "wgp",
+ project_id_for_task=original_params.get("project_id"),
+ image_download_dir=None,
+ apply_reward_lora=False,
+ colour_match_videos=False,
+ mask_active_frames=True,
+ )
+
+ if not generation_success:
+ return False, "Final video generation failed.", None
+
+ final_video_path = db_ops.get_abs_path_from_db_path(final_video_output_db, dprint)
+ if not final_video_path:
+ return False, f"Could not resolve final video path from '{final_video_output_db}'", None
+
+ print("\nDP Final Gen: Extracting final posed image...")
+
+ # Use prepare_output_path_with_upload for Supabase-compatible output handling
+ final_image_filename = f"final_posed_image_{orchestrator_payload['run_id']}.png"
+ final_posed_image_output_path, initial_db_location = prepare_output_path_with_upload(
+ task_id=payload.get("task_id", "dp_final_gen"),
+ filename=final_image_filename,
+ main_output_dir_base=Path(orchestrator_payload["main_output_dir"]),
+ dprint=dprint
+ )
+
+ if not save_frame_from_video(final_video_path, -1, final_posed_image_output_path, sm_parse_resolution(original_params.get("resolution"))):
+ return False, f"Failed to extract final posed image from {final_video_path}", None
+
+ # Handle Supabase upload (if configured) and get final location for DB
+ final_path_for_db = upload_and_get_final_output_location(
+ final_posed_image_output_path,
+ final_image_filename, # Pass only the filename to avoid redundant subfolder
+ initial_db_location,
+ dprint=dprint
+ )
+
+ print(f"Successfully completed 'different_perspective' task! Final image: {final_posed_image_output_path.resolve()} (DB location: {final_path_for_db})")
+
+ # Preserve intermediates when either the orchestrator payload says so *or*
+ # the headless server is running with --debug (exposed via db_ops.debug_mode).
+ if (not orchestrator_payload.get("skip_cleanup") and
+ not orchestrator_payload.get("debug_mode") and
+ not db_ops.debug_mode):
+ print(f"DP Final Gen: Cleaning up intermediate files in {work_dir}...")
+ try:
+ shutil.rmtree(work_dir)
+ print(f"Removed intermediate directory: {work_dir}")
+ except OSError as e_clean:
+ print(f"Error removing intermediate directory {work_dir}: {e_clean}")
+ else:
+ print(f"Skipping cleanup of intermediate files in {work_dir}.")
+
+ db_ops.update_task_status(orchestrator_payload['orchestrator_task_id'], db_ops.STATUS_COMPLETE, final_path_for_db)
+ dprint(f"DP Final Gen: Process complete. Final image at {final_path_for_db}")
+
+ return True, final_path_for_db
+
+ except Exception as e:
+ error_msg = f"DP Final Gen failed: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ db_ops.update_task_status(orchestrator_payload['orchestrator_task_id'], db_ops.STATUS_FAILED, error_msg)
+ return False, error_msg
\ No newline at end of file
diff --git a/source/sm_functions/magic_edit.py b/source/sm_functions/magic_edit.py
new file mode 100644
index 000000000..e4edbc8be
--- /dev/null
+++ b/source/sm_functions/magic_edit.py
@@ -0,0 +1,203 @@
+"""Magic Edit functionality using Replicate API.
+
+This module provides the magic_edit task handler that processes images through
+the flux-kontext-apps/in-scene model on Replicate to generate scene variations.
+"""
+
+import os
+import traceback
+import tempfile
+import shutil
+from pathlib import Path
+import requests
+from typing import Tuple
+
+try:
+ import replicate
+except ImportError:
+ replicate = None
+
+from .. import db_operations as db_ops
+from ..common_utils import (
+ sm_get_unique_target_path,
+ download_image_if_url as sm_download_image_if_url,
+ prepare_output_path_with_upload,
+ upload_and_get_final_output_location,
+ report_orchestrator_failure
+)
+
+
+def _handle_magic_edit_task(
+ task_params_from_db: dict,
+ main_output_dir_base: Path,
+ task_id: str,
+ *,
+ dprint
+) -> Tuple[bool, str]:
+ """
+ Handle a magic_edit task by processing an image through Replicate's black-forest-labs/flux-kontext-dev-lora model.
+
+ Args:
+ task_params_from_db: Task parameters containing orchestrator_details
+ main_output_dir_base: Base directory for outputs
+ task_id: Unique task identifier
+ dprint: Debug print function
+
+ Returns:
+ Tuple of (success_bool, output_location_or_error_message)
+ """
+ print(f"[Task ID: {task_id}] Starting magic_edit task")
+
+ try:
+ # Check if replicate is available
+ if replicate is None:
+ msg = "Replicate library not installed. Run: pip install replicate"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_from_db, msg, dprint)
+ return False, msg
+
+ # Check for API token
+ api_token = os.getenv("REPLICATE_API_TOKEN")
+ if not api_token:
+ msg = "REPLICATE_API_TOKEN environment variable not set"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_from_db, msg, dprint)
+ return False, msg
+
+ # Extract orchestrator details
+ orchestrator_details = task_params_from_db.get("orchestrator_details", {})
+ if not orchestrator_details:
+ msg = "No orchestrator_details found in task parameters"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_from_db, msg, dprint)
+ return False, msg
+
+ # Required parameters
+ image_url = orchestrator_details.get("image_url")
+ prompt = orchestrator_details.get("prompt", "Make a shot in the same scene from a different angle")
+ resolution = orchestrator_details.get("resolution", "768x576")
+ seed = orchestrator_details.get("seed")
+ in_scene = orchestrator_details.get("in_scene", False) # Default to False
+
+ if not image_url:
+ msg = "image_url is required in orchestrator_details"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_from_db, msg, dprint)
+ return False, msg
+
+ dprint(f"Task {task_id}: Magic edit parameters - Image: {image_url}, Prompt: {prompt}, Resolution: {resolution}, In-Scene LoRA: {in_scene}")
+
+ # Create temporary directory for processing
+ temp_dir = Path(tempfile.mkdtemp(prefix=f"magic_edit_{task_id}_"))
+ dprint(f"Task {task_id}: Created temp directory: {temp_dir}")
+
+ try:
+ # Validate that we have a URL (Replicate expects URLs, not local files)
+ if not image_url.startswith(("http://", "https://")):
+ msg = f"image_url must be a valid HTTP/HTTPS URL, got: {image_url}"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ return False, msg
+
+ dprint(f"Task {task_id}: Using image URL directly with Replicate: {image_url}")
+
+ # Prepare Replicate input
+ replicate_input = {
+ "prompt": prompt,
+ "input_image": image_url,
+ "aspect_ratio": "match_input_image", # Let Replicate match the input image aspect ratio
+ "lora_strength": 1,
+ "output_format": "webp",
+ "output_quality": 90
+ }
+
+ # Only add LoRA weights if in_scene is True
+ if in_scene:
+ replicate_input["lora_weights"] = "https://huggingface.co/peteromallet/Flux-Kontext-InScene/resolve/main/InScene-v1.0.safetensors"
+ dprint(f"Task {task_id}: Using InScene LoRA for scene-consistent generation")
+ else:
+ dprint(f"Task {task_id}: Using base Flux model without InScene LoRA")
+
+ # Add seed if specified
+ if seed is not None:
+ replicate_input["seed"] = int(seed)
+
+ print(f"[Task ID: {task_id}] Running Replicate black-forest-labs/flux-kontext-dev-lora model...")
+ dprint(f"Task {task_id}: Replicate input: {replicate_input}")
+
+ # Run the model
+ output = replicate.run(
+ "black-forest-labs/flux-kontext-dev-lora",
+ input=replicate_input
+ )
+
+ print(f"[Task ID: {task_id}] Replicate processing completed")
+ dprint(f"Task {task_id}: Replicate output type: {type(output)}")
+
+ # Download the result
+ output_image_path = temp_dir / f"magic_edit_{task_id}.webp"
+
+ if hasattr(output, 'read'):
+ # Output is a file-like object
+ with open(output_image_path, 'wb') as f:
+ f.write(output.read())
+            elif hasattr(output, 'url'):
+                # Output exposes a URL (method or attribute, depending on the replicate client version)
+                url_value = output.url() if callable(output.url) else output.url
+                response = requests.get(url_value, timeout=60)
+ response.raise_for_status()
+ with open(output_image_path, 'wb') as f:
+ f.write(response.content)
+ elif isinstance(output, str) and output.startswith(("http://", "https://")):
+ # Output is a URL string
+ response = requests.get(output, timeout=60)
+ response.raise_for_status()
+ with open(output_image_path, 'wb') as f:
+ f.write(response.content)
+ else:
+ msg = f"Unexpected output type from Replicate: {type(output)}"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ return False, msg
+
+ if not output_image_path.exists() or output_image_path.stat().st_size == 0:
+ msg = "Failed to download result from Replicate"
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ return False, msg
+
+ print(f"[Task ID: {task_id}] Downloaded result image: {output_image_path.name} ({output_image_path.stat().st_size} bytes)")
+
+ # Prepare final output path
+ final_output_path, initial_db_location = prepare_output_path_with_upload(
+ task_id=task_id,
+ filename=output_image_path.name,
+ main_output_dir_base=main_output_dir_base,
+ dprint=dprint
+ )
+
+ # Move to final location
+ shutil.move(str(output_image_path), str(final_output_path))
+ dprint(f"Task {task_id}: Moved result to final location: {final_output_path}")
+
+ # Handle upload and get final DB location
+ final_db_location = upload_and_get_final_output_location(
+ final_output_path,
+ task_id,
+ initial_db_location,
+ dprint=dprint
+ )
+
+ print(f"[Task ID: {task_id}] Magic edit completed successfully: {final_output_path.resolve()}")
+ return True, final_db_location
+
+ finally:
+ # Cleanup temp directory
+ try:
+ shutil.rmtree(temp_dir)
+ dprint(f"Task {task_id}: Cleaned up temp directory: {temp_dir}")
+ except Exception as e_cleanup:
+ print(f"[WARNING Task ID: {task_id}] Failed to cleanup temp directory {temp_dir}: {e_cleanup}")
+
+ except Exception as e:
+ print(f"[ERROR Task ID: {task_id}] Magic edit task failed: {e}")
+ traceback.print_exc()
+ msg = f"Magic edit exception: {e}"
+ report_orchestrator_failure(task_params_from_db, msg, dprint)
+ return False, msg
\ No newline at end of file
diff --git a/source/sm_functions/single_image.py b/source/sm_functions/single_image.py
new file mode 100644
index 000000000..4e22e9e96
--- /dev/null
+++ b/source/sm_functions/single_image.py
@@ -0,0 +1,237 @@
+"""Single image generation task handler."""
+
+import json
+import tempfile
+import traceback
+from pathlib import Path
+
+# Import from the restructured modules
+from .. import db_operations as db_ops
+from ..common_utils import (
+ sm_get_unique_target_path,
+ download_image_if_url as sm_download_image_if_url,
+ load_pil_images as sm_load_pil_images,
+ parse_resolution as sm_parse_resolution,
+ build_task_state,
+ prepare_output_path,
+ process_additional_loras_shared, # New shared function
+ snap_resolution_to_model_grid, # New shared function
+ prepare_output_path_with_upload, # New shared function
+ upload_and_get_final_output_location # New shared function
+)
+from ..wgp_utils import generate_single_video
+
+
+def _handle_single_image_task(wgp_mod, task_params_from_db: dict, main_output_dir_base: Path, task_id: str, image_download_dir: Path | str | None = None, apply_reward_lora: bool = False, *, dprint):
+ """
+ Handles single image generation tasks.
+
+ Args:
+ wgp_mod: The WGP module for generation
+ task_params_from_db: Task parameters from the database
+ main_output_dir_base: Base output directory
+ task_id: Task ID for logging
+ image_download_dir: Directory for downloading images if URLs are provided
+ apply_reward_lora: Whether to apply reward LoRA
+ dprint: Debug print function
+
+ Returns:
+ Tuple[bool, str]: (success, output_location_or_error_message)
+ """
+ dprint(f"_handle_single_image_task: Starting for {task_id}")
+ dprint(f"Single image task params: {json.dumps(task_params_from_db, default=str, indent=2)}")
+
+ try:
+ # -------------------------------------------------------------
+ # Flatten orchestrator_details (if present) so that nested keys
+ # like `use_causvid_lora` or `prompt` become first-class entries.
+ # Top-level keys take precedence over nested ones in case of clash.
+ # -------------------------------------------------------------
+ if isinstance(task_params_from_db.get("orchestrator_details"), dict):
+ task_params_from_db = {
+ **task_params_from_db["orchestrator_details"], # Nested first
+ **{k: v for k, v in task_params_from_db.items() if k != "orchestrator_details"}, # Top-level override
+ }
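+        # e.g. {"seed": 7, "orchestrator_details": {"seed": 3, "prompt": "a cat"}}
+        #      flattens to {"prompt": "a cat", "seed": 7} (top-level values win on clashes).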
+
+ # Extract required parameters with defaults
+ prompt = task_params_from_db.get("prompt", " ").strip() or " " # Default to space if empty
+ model_name = task_params_from_db.get("model", "t2v") # Default to t2v model
+ resolution = task_params_from_db.get("resolution", "832x480")
+ seed = task_params_from_db.get("seed", -1)
+ negative_prompt = task_params_from_db.get("negative_prompt", "").strip() or " " # Default to space if empty
+
+ # Validate and parse resolution with model grid snapping
+ try:
+ parsed_res = sm_parse_resolution(resolution)
+ if parsed_res is None:
+ raise ValueError(f"Invalid resolution format: {resolution}")
+ width, height = snap_resolution_to_model_grid(parsed_res)
+ resolution = f"{width}x{height}"
+ dprint(f"Single image task {task_id}: Adjusted resolution to {resolution}")
+ except Exception as e_res:
+ error_msg = f"Single image task {task_id}: Resolution parsing failed: {e_res}"
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+
+ # Handle reference images if provided
+ image_refs_paths = []
+ if task_params_from_db.get("image_refs_paths"):
+ try:
+ loaded_refs = sm_load_pil_images(
+ task_params_from_db["image_refs_paths"],
+ wgp_mod.convert_image,
+ image_download_dir,
+ task_id,
+ dprint
+ )
+                if loaded_refs:
+                    # Paths load successfully; pass the original path list through in the payload
+                    image_refs_paths = task_params_from_db["image_refs_paths"]
+ except Exception as e_refs:
+ dprint(f"Single image task {task_id}: Warning - failed to load reference images: {e_refs}")
+
+ # Determine model filename for LoRA handling
+ model_filename_for_task = wgp_mod.get_model_filename(
+ model_name,
+ wgp_mod.transformer_quantization,
+ wgp_mod.transformer_dtype_policy
+ )
+
+ # Handle additional LoRAs using shared function
+ processed_additional_loras = {}
+ additional_loras = task_params_from_db.get("additional_loras", {})
+ if additional_loras:
+ dprint(f"Single image task {task_id}: Processing additional LoRAs: {additional_loras}")
+ processed_additional_loras = process_additional_loras_shared(
+ additional_loras,
+ wgp_mod,
+ model_filename_for_task,
+ task_id,
+ dprint
+ )
+
+ # Prepare the output path (with Supabase upload support)
+ output_filename = f"single_image_{task_id}.png"
+ local_output_path, initial_db_output_location = prepare_output_path_with_upload(
+ task_id,
+ output_filename,
+ main_output_dir_base,
+ dprint=dprint
+ )
+
+ # Create a temporary directory for WGP processing
+ with tempfile.TemporaryDirectory(prefix=f"single_img_{task_id}_") as temp_dir:
+ temp_dir_path = Path(temp_dir)
+ temp_video_path = temp_dir_path / f"{task_id}_temp.mp4"
+
+ dprint(f"Single image task {task_id}: Using temp directory {temp_dir_path}")
+
+ # Set up WGP module state temporarily
+ original_save_path = wgp_mod.save_path
+ wgp_mod.save_path = str(temp_dir_path)
+
+ try:
+ dprint(f"Single image task {task_id}: Calling generate_single_video with new flexible API")
+
+ # Use the new flexible keyword-style API
+ use_causvid = task_params_from_db.get("use_causvid_lora", False)
+ use_lighti2x = task_params_from_db.get("use_lighti2x_lora", False)
+
+ num_inference_steps = (
+ task_params_from_db.get("steps")
+ or task_params_from_db.get("num_inference_steps")
+ or (8 if use_causvid else (5 if use_lighti2x else 30))
+ )
+
+ if use_causvid:
+ default_guidance = 1.0
+ default_flow_shift = 1.0
+ elif use_lighti2x:
+ default_guidance = 1.0
+ default_flow_shift = 5.0
+ else:
+ default_guidance = task_params_from_db.get("guidance_scale", 5.0)
+ default_flow_shift = task_params_from_db.get("flow_shift", 3.0)
+
+ generation_success, video_path_generated = generate_single_video(
+ wgp_mod=wgp_mod,
+ task_id=f"{task_id}_wgp_internal",
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ resolution=resolution,
+ video_length=1, # Single frame
+ seed=seed,
+ model_filename=model_filename_for_task,
+ use_causvid_lora=use_causvid,
+ use_lighti2x_lora=use_lighti2x,
+ apply_reward_lora=apply_reward_lora,
+ additional_loras=processed_additional_loras,
+ image_refs=image_refs_paths,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=default_guidance,
+ flow_shift=default_flow_shift,
+ cfg_star_switch=task_params_from_db.get("cfg_star_switch", 0),
+ cfg_zero_step=task_params_from_db.get("cfg_zero_step", -1),
+ prompt_enhancer=task_params_from_db.get("prompt_enhancer_mode", ""),
+ dprint=dprint
+ )
+
+ if not generation_success or not video_path_generated:
+ error_msg = f"Single image task {task_id}: WGP generation failed."
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+
+ # Convert video path to Path object if it's a string
+ video_path_obj = Path(video_path_generated)
+ if not video_path_obj.exists():
+ error_msg = f"Single image task {task_id}: WGP generation failed - no output video found at {video_path_generated}"
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+
+ print(f"[Single Image {task_id}] Extracting frame from generated video...")
+
+ # Extract the first (and only) frame using cv2
+ import cv2
+ cap = cv2.VideoCapture(str(video_path_obj))
+ try:
+ if cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ # Save the frame as PNG to the final location
+ success = cv2.imwrite(str(local_output_path), frame)
+ if success and local_output_path.exists():
+ print(f"[Single Image {task_id}] Successfully saved image to: {local_output_path}")
+
+ # Handle Supabase upload (if configured) and get final location for DB
+ final_db_location = upload_and_get_final_output_location(
+ local_file_path=local_output_path,
+ supabase_object_name=output_filename, # Pass only the filename
+ initial_db_location=initial_db_output_location,
+ dprint=dprint
+ )
+
+ return True, final_db_location
+ else:
+ error_msg = f"Single image task {task_id}: Failed to save extracted frame to {local_output_path}"
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+ else:
+ error_msg = f"Single image task {task_id}: Failed to read frame from generated video"
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+ else:
+ error_msg = f"Single image task {task_id}: Failed to open generated video {video_path_obj}"
+ print(f"[ERROR] {error_msg}")
+ return False, error_msg
+ finally:
+ cap.release()
+
+ finally:
+ # Restore original WGP save path
+ wgp_mod.save_path = original_save_path
+
+ except Exception as e:
+ error_msg = f"Single image task {task_id}: Unexpected error: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ return False, error_msg
\ No newline at end of file
diff --git a/source/sm_functions/travel_between_images.py b/source/sm_functions/travel_between_images.py
new file mode 100644
index 000000000..5ec38ea16
--- /dev/null
+++ b/source/sm_functions/travel_between_images.py
@@ -0,0 +1,1927 @@
+import json
+import math
+import shutil
+import traceback
+from pathlib import Path
+import time
+import subprocess
+
+try:
+ import cv2
+ import numpy as np
+ _COLOR_MATCH_DEPS_AVAILABLE = True
+except ImportError:
+ _COLOR_MATCH_DEPS_AVAILABLE = False
+
+# --- SM_RESTRUCTURE: Import moved/new utilities ---
+from .. import db_operations as db_ops
+from ..common_utils import (
+ generate_unique_task_id as sm_generate_unique_task_id,
+ get_video_frame_count_and_fps as sm_get_video_frame_count_and_fps,
+ sm_get_unique_target_path,
+ create_color_frame as sm_create_color_frame,
+ parse_resolution as sm_parse_resolution,
+ download_image_if_url as sm_download_image_if_url,
+ prepare_output_path,
+ prepare_output_path_with_upload,
+ upload_and_get_final_output_location,
+ snap_resolution_to_model_grid,
+ ensure_valid_prompt,
+ ensure_valid_negative_prompt,
+ process_additional_loras_shared,
+ wait_for_file_stable as sm_wait_for_file_stable,
+)
+from ..video_utils import (
+ extract_frames_from_video as sm_extract_frames_from_video,
+ create_video_from_frames_list as sm_create_video_from_frames_list,
+ cross_fade_overlap_frames as sm_cross_fade_overlap_frames,
+ _apply_saturation_to_video_ffmpeg as sm_apply_saturation_to_video_ffmpeg,
+ apply_brightness_to_video_frames,
+ prepare_vace_ref_for_segment as sm_prepare_vace_ref_for_segment,
+ create_guide_video_for_travel_segment as sm_create_guide_video_for_travel_segment,
+ apply_color_matching_to_video as sm_apply_color_matching_to_video,
+ extract_last_frame_as_image as sm_extract_last_frame_as_image,
+ overlay_start_end_images_above_video as sm_overlay_start_end_images_above_video,
+)
+from ..wgp_utils import generate_single_video
+
+# Add debugging helper function
+def debug_video_analysis(video_path: str | Path, label: str, task_id: str = "unknown") -> dict:
+ """Analyze a video file and return comprehensive debug info"""
+ try:
+ path_obj = Path(video_path)
+ if not path_obj.exists():
+ print(f"[VIDEO_DEBUG] {label} ({task_id}): FILE MISSING - {video_path}")
+ return {"exists": False, "path": str(video_path)}
+
+ frame_count, fps = sm_get_video_frame_count_and_fps(str(path_obj))
+ file_size = path_obj.stat().st_size
+ duration = frame_count / fps if fps and fps > 0 else 0
+
+ debug_info = {
+ "exists": True,
+ "path": str(path_obj.resolve()),
+ "frame_count": frame_count,
+ "fps": fps,
+ "duration_seconds": duration,
+ "file_size_bytes": file_size,
+ "file_size_mb": round(file_size / (1024*1024), 2)
+ }
+
+ print(f"[VIDEO_DEBUG] {label} ({task_id}):")
+ print(f"[VIDEO_DEBUG] Path: {debug_info['path']}")
+ print(f"[VIDEO_DEBUG] Frames: {debug_info['frame_count']}")
+ print(f"[VIDEO_DEBUG] FPS: {debug_info['fps']}")
+ print(f"[VIDEO_DEBUG] Duration: {debug_info['duration_seconds']:.2f}s")
+ print(f"[VIDEO_DEBUG] Size: {debug_info['file_size_mb']} MB")
+
+ return debug_info
+
+ except Exception as e:
+ print(f"[VIDEO_DEBUG] {label} ({task_id}): ERROR analyzing video - {e}")
+ return {"exists": False, "error": str(e), "path": str(video_path)}
+
+# --- SM_RESTRUCTURE: New Handler Functions for Travel Tasks ---
+def _handle_travel_orchestrator_task(task_params_from_db: dict, main_output_dir_base: Path, orchestrator_task_id_str: str, orchestrator_project_id: str | None, *, dprint):
+ dprint(f"_handle_travel_orchestrator_task: Starting for {orchestrator_task_id_str}")
+ dprint(f"Orchestrator Project ID: {orchestrator_project_id}") # Added dprint
+ dprint(f"Orchestrator task_params_from_db (first 1000 chars): {json.dumps(task_params_from_db, default=str, indent=2)[:1000]}...")
+ generation_success = False # Represents success of orchestration step
+ output_message_for_orchestrator_db = f"Orchestration for {orchestrator_task_id_str} initiated."
+
+ try:
+ if 'orchestrator_details' not in task_params_from_db:
+ msg = f"[ERROR Task ID: {orchestrator_task_id_str}] 'orchestrator_details' not found in task_params_from_db."
+ print(msg)
+ return False, msg
+
+ orchestrator_payload = task_params_from_db['orchestrator_details']
+ dprint(f"Orchestrator payload for {orchestrator_task_id_str} (first 500 chars): {json.dumps(orchestrator_payload, indent=2, default=str)[:500]}...")
+
+ run_id = orchestrator_payload.get("run_id", orchestrator_task_id_str)
+ base_dir_for_this_run_str = orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve()))
+
+ # Use the base directory directly without creating run-specific subdirectories
+ current_run_output_dir = Path(base_dir_for_this_run_str)
+ current_run_output_dir.mkdir(parents=True, exist_ok=True)
+ dprint(f"Orchestrator {orchestrator_task_id_str}: Base output directory for this run: {current_run_output_dir.resolve()}")
+
+ num_segments = orchestrator_payload.get("num_new_segments_to_generate", 0)
+ if num_segments <= 0:
+ msg = f"[WARNING Task ID: {orchestrator_task_id_str}] No new segments to generate based on orchestrator payload. Orchestration complete (vacuously)."
+ print(msg)
+ return True, msg
+
+ db_path_for_add = db_ops.SQLITE_DB_PATH if db_ops.DB_TYPE == "sqlite" else None
+ previous_segment_task_id = None
+
+ # --- Determine image download directory for this orchestrated run ---
+ segment_image_download_dir_str : str | None = None
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH: # SQLITE_DB_PATH is global
+ try:
+ sqlite_db_path_obj = Path(db_ops.SQLITE_DB_PATH).resolve()
+ if sqlite_db_path_obj.is_file():
+ sqlite_db_parent_dir = sqlite_db_path_obj.parent
+ # Orchestrated downloads go into a subfolder named after the run_id
+ candidate_download_dir = sqlite_db_parent_dir / "public" / "data" / "image_downloads_orchestrated" / run_id
+ candidate_download_dir.mkdir(parents=True, exist_ok=True)
+ segment_image_download_dir_str = str(candidate_download_dir.resolve())
+ dprint(f"Orchestrator {orchestrator_task_id_str}: Determined segment_image_download_dir for run {run_id}: {segment_image_download_dir_str}")
+ else:
+ dprint(f"Orchestrator {orchestrator_task_id_str}: SQLITE_DB_PATH '{db_ops.SQLITE_DB_PATH}' is not a file. Cannot determine parent for image_download_dir.")
+ except Exception as e_idir_orch:
+ dprint(f"Orchestrator {orchestrator_task_id_str}: Could not create image_download_dir for run {run_id}: {e_idir_orch}. Segments may not download URL images to specific dir.")
+ # Add similar logic for Supabase if a writable shared path convention exists.
+
+ # Expanded arrays from orchestrator payload
+ expanded_base_prompts = orchestrator_payload["base_prompts_expanded"]
+ expanded_negative_prompts = orchestrator_payload["negative_prompts_expanded"]
+ expanded_segment_frames = orchestrator_payload["segment_frames_expanded"]
+ expanded_frame_overlap = orchestrator_payload["frame_overlap_expanded"]
+ vace_refs_instructions_all = orchestrator_payload.get("vace_image_refs_to_prepare_by_headless", [])
+
+ # Preserve a copy of the original overlap list in case we need it later
+ _orig_frame_overlap = list(expanded_frame_overlap) # shallow copy
+
+ # --- SM_QUANTIZE_FRAMES_AND_OVERLAPS ---
+ # Adjust all segment lengths to match model constraints (4*N+1 format).
+ # Then, adjust overlap values to be even and not exceed the length of the
+ # smaller of the two segments they connect. This prevents errors downstream
+ # in guide video creation, generation, and stitching.
+
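+        # Worked example (illustrative): a requested segment length of 30 frames is
+        # quantized to 29 (4*7+1); a requested overlap of 15 is rounded down to 14
+        # (even), then capped at the shorter of the two adjoining quantized segments.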
+ print(f"[FRAME_DEBUG] Orchestrator {orchestrator_task_id_str}: QUANTIZATION ANALYSIS")
+ print(f"[FRAME_DEBUG] Original segment_frames_expanded: {expanded_segment_frames}")
+ print(f"[FRAME_DEBUG] Original frame_overlap: {expanded_frame_overlap}")
+
+ quantized_segment_frames = []
+ dprint(f"Orchestrator: Quantizing frame counts. Original segment_frames_expanded: {expanded_segment_frames}")
+ for i, frames in enumerate(expanded_segment_frames):
+ # Quantize to 4*N+1 format to match model constraints, applied later in headless.py
+ new_frames = (frames // 4) * 4 + 1
+ print(f"[FRAME_DEBUG] Segment {i}: {frames} -> {new_frames} (4*N+1 quantization)")
+ if new_frames != frames:
+ dprint(f"Orchestrator: Quantized segment {i} length from {frames} to {new_frames} (4*N+1 format).")
+ quantized_segment_frames.append(new_frames)
+
+ print(f"[FRAME_DEBUG] Quantized segment_frames: {quantized_segment_frames}")
+ dprint(f"Orchestrator: Finished quantizing frame counts. New quantized_segment_frames: {quantized_segment_frames}")
+
+ quantized_frame_overlap = []
+ # There are N-1 overlaps for N segments. The loop must not iterate more times than this.
+ num_overlaps_to_process = len(quantized_segment_frames) - 1
+ print(f"[FRAME_DEBUG] Processing {num_overlaps_to_process} overlap values")
+
+ if num_overlaps_to_process > 0:
+ for i in range(num_overlaps_to_process):
+ # Gracefully handle if the original overlap array is longer than expected.
+ if i < len(expanded_frame_overlap):
+ original_overlap = expanded_frame_overlap[i]
+ else:
+ # This case should not happen if client is correct, but as a fallback.
+ dprint(f"Orchestrator: Overlap at index {i} missing. Defaulting to 0.")
+ original_overlap = 0
+
+ # Overlap connects segment i and i+1.
+ # It cannot be larger than the shorter of the two segments.
+ max_possible_overlap = min(quantized_segment_frames[i], quantized_segment_frames[i+1])
+
+ # Quantize original overlap to be even, then cap it.
+ new_overlap = (original_overlap // 2) * 2
+ new_overlap = min(new_overlap, max_possible_overlap)
+ if new_overlap < 0: new_overlap = 0
+
+ print(f"[FRAME_DEBUG] Overlap {i} (segments {i}->{i+1}): {original_overlap} -> {new_overlap}")
+ print(f"[FRAME_DEBUG] Segment lengths: {quantized_segment_frames[i]}, {quantized_segment_frames[i+1]}")
+ print(f"[FRAME_DEBUG] Max possible overlap: {max_possible_overlap}")
+
+ if new_overlap != original_overlap:
+ dprint(f"Orchestrator: Adjusted overlap between segments {i}-{i+1} from {original_overlap} to {new_overlap}.")
+
+ quantized_frame_overlap.append(new_overlap)
+
+ print(f"[FRAME_DEBUG] Final quantized_frame_overlap: {quantized_frame_overlap}")
+
+ # Persist quantised results back to orchestrator_payload so all downstream tasks see them
+ orchestrator_payload["segment_frames_expanded"] = quantized_segment_frames
+ orchestrator_payload["frame_overlap_expanded"] = quantized_frame_overlap
+
+ # Calculate expected final length
+ total_input_frames = sum(quantized_segment_frames)
+ total_overlaps = sum(quantized_frame_overlap)
+ expected_final_length = total_input_frames - total_overlaps
+ print(f"[FRAME_DEBUG] EXPECTED FINAL VIDEO:")
+ print(f"[FRAME_DEBUG] Total input frames: {total_input_frames}")
+ print(f"[FRAME_DEBUG] Total overlaps: {total_overlaps}")
+ print(f"[FRAME_DEBUG] Expected final length: {expected_final_length} frames")
+ print(f"[FRAME_DEBUG] Expected duration: {expected_final_length / orchestrator_payload.get('fps_helpers', 16):.2f}s")
+
+ # Replace original lists with the new quantized ones for all subsequent logic
+ expanded_segment_frames = quantized_segment_frames
+ expanded_frame_overlap = quantized_frame_overlap
+ # --- END SM_QUANTIZE_FRAMES_AND_OVERLAPS ---
+
+ # If quantisation resulted in an empty overlap list (e.g. single-segment run) but the
+ # original payload DID contain an overlap value, restore that so the first segment
+ # can still reuse frames from the previous/continued video. This is crucial for
+ # continue-video journeys where we expect `frame_overlap_from_previous` > 0.
+ if (not expanded_frame_overlap) and _orig_frame_overlap:
+ expanded_frame_overlap = _orig_frame_overlap
+
+ for idx in range(num_segments):
+ current_segment_task_id = sm_generate_unique_task_id(f"travel_seg_{run_id}_{idx:02d}_")
+
+ # Note: segment handler now manages its own output paths using prepare_output_path()
+
+ # Determine frame_overlap_from_previous for current segment `idx`
+ current_frame_overlap_from_previous = 0
+ if idx == 0 and orchestrator_payload.get("continue_from_video_resolved_path"):
+ current_frame_overlap_from_previous = expanded_frame_overlap[0] if expanded_frame_overlap else 0
+ elif idx > 0:
+ # SM_RESTRUCTURE_FIX_OVERLAP_IDX: Use idx-1 for subsequent segments
+ current_frame_overlap_from_previous = expanded_frame_overlap[idx-1] if len(expanded_frame_overlap) > (idx-1) else 0
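+            # e.g. with overlaps [14, 14]: segment 0 gets 0 (or 14 when continuing from a
+            # previous video), segment 1 gets overlaps[0]=14, segment 2 gets overlaps[1]=14.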
+
+ # VACE refs for this specific segment
+ # Ensure vace_refs_instructions_all is a list, default to empty list if None
+ vace_refs_safe = vace_refs_instructions_all if vace_refs_instructions_all is not None else []
+ vace_refs_for_this_segment = [
+ ref_instr for ref_instr in vace_refs_safe
+ if ref_instr.get("segment_idx_for_naming") == idx
+ ]
+
+ segment_payload = {
+ "task_id": current_segment_task_id,
+ "orchestrator_task_id_ref": orchestrator_task_id_str,
+ "orchestrator_run_id": run_id,
+ "project_id": orchestrator_project_id, # Added project_id
+ "segment_index": idx,
+ "is_first_segment": (idx == 0),
+ "is_last_segment": (idx == num_segments - 1),
+
+ "current_run_base_output_dir": str(current_run_output_dir.resolve()), # Base for segment's own output folder creation
+
+ "base_prompt": expanded_base_prompts[idx],
+ "negative_prompt": expanded_negative_prompts[idx],
+ "segment_frames_target": expanded_segment_frames[idx],
+ "frame_overlap_from_previous": current_frame_overlap_from_previous,
+ "frame_overlap_with_next": expanded_frame_overlap[idx] if len(expanded_frame_overlap) > idx else 0,
+
+ "vace_image_refs_to_prepare_by_headless": vace_refs_for_this_segment, # Already filtered for this segment
+
+ "parsed_resolution_wh": orchestrator_payload["parsed_resolution_wh"],
+ "model_name": orchestrator_payload["model_name"],
+ "seed_to_use": orchestrator_payload.get("seed_base", 12345) + idx,
+ "use_causvid_lora": orchestrator_payload.get("apply_causvid", False),
+ "apply_reward_lora": orchestrator_payload.get("apply_reward_lora", False),
+ "cfg_star_switch": orchestrator_payload.get("cfg_star_switch", 0),
+ "cfg_zero_step": orchestrator_payload.get("cfg_zero_step", -1),
+ "params_json_str_override": orchestrator_payload.get("params_json_str_override"),
+ "fps_helpers": orchestrator_payload.get("fps_helpers", 16),
+ "fade_in_params_json_str": orchestrator_payload["fade_in_params_json_str"],
+ "fade_out_params_json_str": orchestrator_payload["fade_out_params_json_str"],
+ "subsequent_starting_strength_adjustment": orchestrator_payload.get("subsequent_starting_strength_adjustment", 0.0),
+ "desaturate_subsequent_starting_frames": orchestrator_payload.get("desaturate_subsequent_starting_frames", 0.0),
+ "adjust_brightness_subsequent_starting_frames": orchestrator_payload.get("adjust_brightness_subsequent_starting_frames", 0.0),
+ "after_first_post_generation_saturation": orchestrator_payload.get("after_first_post_generation_saturation"),
+ "after_first_post_generation_brightness": orchestrator_payload.get("after_first_post_generation_brightness"),
+
+ "segment_image_download_dir": segment_image_download_dir_str, # Add the download dir path string
+
+ "debug_mode_enabled": orchestrator_payload.get("debug_mode_enabled", False),
+ "skip_cleanup_enabled": orchestrator_payload.get("skip_cleanup_enabled", False),
+ "continue_from_video_resolved_path_for_guide": orchestrator_payload.get("continue_from_video_resolved_path") if idx == 0 else None,
+ "full_orchestrator_payload": orchestrator_payload, # Ensure full payload is passed to segment
+ }
+
+ dprint(f"Orchestrator: Enqueuing travel_segment {idx} (ID: {current_segment_task_id}) depends_on={previous_segment_task_id}")
+ db_ops.add_task_to_db(
+ task_payload=segment_payload,
+ task_type_str="travel_segment",
+ dependant_on=previous_segment_task_id
+ )
+        dprint(f"Orchestrator {orchestrator_task_id_str}: Enqueued travel_segment {idx} (ID: {current_segment_task_id}) with payload (first 500 chars): {json.dumps(segment_payload, default=str)[:500]}... Depends on: {previous_segment_task_id}")
+        previous_segment_task_id = current_segment_task_id
+
+ # After loop, enqueue the stitch task
+ stitch_task_id = sm_generate_unique_task_id(f"travel_stitch_{run_id}_")
+ final_stitched_video_name = f"travel_final_stitched_{run_id}.mp4"
+ # Stitcher saves its final primary output directly under main_output_dir (e.g., ./steerable_motion_output/)
+ # NOT under current_run_output_dir (which is .../travel_run_XYZ/)
+ # The main_output_dir_base is the one passed to headless.py (e.g. server's ./outputs or steerable_motion's ./steerable_motion_output)
+ # The orchestrator_payload["main_output_dir_for_run"] is this main_output_dir_base.
+ final_stitched_output_path = Path(orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve()))) / final_stitched_video_name
+
+ stitch_payload = {
+ "task_id": stitch_task_id,
+ "orchestrator_task_id_ref": orchestrator_task_id_str,
+ "orchestrator_run_id": run_id,
+ "project_id": orchestrator_project_id, # Added project_id
+ "num_total_segments_generated": num_segments,
+ "current_run_base_output_dir": str(current_run_output_dir.resolve()), # Stitcher needs this to find segment outputs
+ "frame_overlap_settings_expanded": expanded_frame_overlap,
+ "crossfade_sharp_amt": orchestrator_payload.get("crossfade_sharp_amt", 0.3),
+ "parsed_resolution_wh": orchestrator_payload["parsed_resolution_wh"],
+ "fps_final_video": orchestrator_payload.get("fps_helpers", 16),
+ "upscale_factor": orchestrator_payload.get("upscale_factor", 0.0),
+ "upscale_model_name": orchestrator_payload.get("upscale_model_name"),
+ "seed_for_upscale": orchestrator_payload.get("seed_base", 12345) + 5000, # Consistent seed for upscale
+ "debug_mode_enabled": orchestrator_payload.get("debug_mode_enabled", False),
+ "skip_cleanup_enabled": orchestrator_payload.get("skip_cleanup_enabled", False),
+ "initial_continued_video_path": orchestrator_payload.get("continue_from_video_resolved_path"),
+ "final_stitched_output_path": str(final_stitched_output_path.resolve()),
+ # For upscale polling, if stitcher enqueues an upscale sub-task
+ "poll_interval_from_orchestrator": orchestrator_payload.get("original_common_args", {}).get("poll_interval", 15),
+ "poll_timeout_from_orchestrator": orchestrator_payload.get("original_common_args", {}).get("poll_timeout", 1800),
+        "full_orchestrator_payload": orchestrator_payload,  # Required by the stitch handler
+ }
+
+ dprint(f"Orchestrator: Enqueuing travel_stitch task (ID: {stitch_task_id}) depends_on={previous_segment_task_id}")
+ db_ops.add_task_to_db(
+ task_payload=stitch_payload,
+ task_type_str="travel_stitch",
+ dependant_on=previous_segment_task_id
+ )
+ dprint(f"Orchestrator {orchestrator_task_id_str}: Enqueued travel_stitch task (ID: {stitch_task_id}) with payload (first 500 chars): {json.dumps(stitch_payload, default=str)[:500]}... Depends on: {previous_segment_task_id}")
+
+ generation_success = True
+ output_message_for_orchestrator_db = f"Successfully enqueued all {num_segments} segment tasks and 1 stitch task for run {run_id}."
+ print(f"Orchestrator {orchestrator_task_id_str}: {output_message_for_orchestrator_db}")
+
+ except Exception as e:
+ msg = f"[ERROR Task ID: {orchestrator_task_id_str}] Failed during travel orchestration processing: {e}"
+ print(msg)
+ traceback.print_exc()
+ generation_success = False
+ output_message_for_orchestrator_db = msg
+
+ return generation_success, output_message_for_orchestrator_db
+
+def _handle_travel_segment_task(wgp_mod, task_params_from_db: dict, main_output_dir_base: Path, segment_task_id_str: str, apply_reward_lora: bool = False, colour_match_videos: bool = False, mask_active_frames: bool = True, *, process_single_task, dprint):
+ dprint(f"_handle_travel_segment_task: Starting for {segment_task_id_str}")
+ dprint(f"Segment task_params_from_db (first 1000 chars): {json.dumps(task_params_from_db, default=str, indent=2)[:1000]}...")
+ # task_params_from_db contains what was enqueued for this specific segment,
+ # including potentially 'full_orchestrator_payload'.
+ segment_params = task_params_from_db
+ generation_success = False # Success of the WGP/Comfy sub-task for this segment
+ final_segment_video_output_path_str = None # Output of the WGP sub-task
+ output_message_for_segment_task = "Segment task initiated."
+
+ try:
+ # --- 1. Initialization & Parameter Extraction ---
+ orchestrator_task_id_ref = segment_params.get("orchestrator_task_id_ref")
+ orchestrator_run_id = segment_params.get("orchestrator_run_id")
+ segment_idx = segment_params.get("segment_index")
+ segment_image_download_dir_str = segment_params.get("segment_image_download_dir") # Get the passed dir
+ segment_image_download_dir = Path(segment_image_download_dir_str) if segment_image_download_dir_str else None
+
+ if orchestrator_task_id_ref is None or orchestrator_run_id is None or segment_idx is None:
+ msg = f"Segment task {segment_task_id_str} missing critical orchestrator refs or segment_index."
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}")
+ return False, msg
+
+ full_orchestrator_payload = segment_params.get("full_orchestrator_payload")
+ if not full_orchestrator_payload:
+ dprint(f"Segment {segment_idx}: full_orchestrator_payload not in direct params. Querying orchestrator task {orchestrator_task_id_ref}")
+ # This logic now uses the db_ops functions
+ orchestrator_task_raw_params_json = db_ops.get_task_params(orchestrator_task_id_ref)
+
+ if orchestrator_task_raw_params_json:
+ try:
+ fetched_params = json.loads(orchestrator_task_raw_params_json) if isinstance(orchestrator_task_raw_params_json, str) else orchestrator_task_raw_params_json
+ full_orchestrator_payload = fetched_params.get("orchestrator_details")
+ if not full_orchestrator_payload:
+ raise ValueError("'orchestrator_details' key missing in fetched orchestrator task params.")
+ dprint(f"Segment {segment_idx}: Successfully fetched orchestrator_details from DB.")
+ except Exception as e_fetch_orc:
+ msg = f"Segment {segment_idx}: Failed to fetch/parse orchestrator_details from DB for task {orchestrator_task_id_ref}: {e_fetch_orc}"
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}")
+ return False, msg
+ else:
+ msg = f"Segment {segment_idx}: Could not retrieve params for orchestrator task {orchestrator_task_id_ref}. Cannot proceed."
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}")
+ return False, msg
+
+ # Now full_orchestrator_payload is guaranteed to be populated or we've exited.
+ # FIX: Prioritize job-specific settings from orchestrator payload over server-wide CLI flags.
+ effective_colour_match_enabled = full_orchestrator_payload.get("colour_match_videos", colour_match_videos)
+ effective_apply_reward_lora = full_orchestrator_payload.get("apply_reward_lora", apply_reward_lora)
+
+ additional_loras = full_orchestrator_payload.get("additional_loras", {})
+ if additional_loras:
+ dprint(f"Segment {segment_idx}: Found additional_loras in orchestrator payload: {additional_loras}")
+
+ current_run_base_output_dir_str = segment_params.get("current_run_base_output_dir")
+ if not current_run_base_output_dir_str: # Should be passed by orchestrator/prev segment
+ current_run_base_output_dir_str = full_orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve()))
+ current_run_base_output_dir_str = str(Path(current_run_base_output_dir_str) / f"travel_run_{orchestrator_run_id}")
+
+ current_run_base_output_dir = Path(current_run_base_output_dir_str)
+ # Use the base directory directly without creating segment-specific subdirectories
+ segment_processing_dir = current_run_base_output_dir
+ segment_processing_dir.mkdir(parents=True, exist_ok=True)
+
+ # ─── Ensure we have a directory for downloading remote images ────────────
+ if segment_image_download_dir is None:
+ segment_image_download_dir = segment_processing_dir / "image_downloads"
+ try:
+ segment_image_download_dir.mkdir(parents=True, exist_ok=True)
+ except Exception as e_mkdir_dl:
+ dprint(f"[WARNING] Segment {segment_idx}: Could not create image_downloads dir {segment_image_download_dir}: {e_mkdir_dl}")
+ dprint(f"Segment {segment_idx} (Task {segment_task_id_str}): Processing in {segment_processing_dir.resolve()} | image_download_dir={segment_image_download_dir}")
+
+ # --- Color Match Reference Image Determination ---
+ start_ref_path_for_cm, end_ref_path_for_cm = None, None
+ if effective_colour_match_enabled:
+ input_images_for_cm = full_orchestrator_payload.get("input_image_paths_resolved", [])
+ is_continuing_for_cm = full_orchestrator_payload.get("continue_from_video_resolved_path") is not None
+
+ if is_continuing_for_cm:
+ if segment_idx == 0:
+ continued_video_path = full_orchestrator_payload.get("continue_from_video_resolved_path")
+ if continued_video_path and Path(continued_video_path).exists():
+ dprint(f"Seg {segment_idx} CM: Extracting last frame from {continued_video_path} as start ref.")
+ start_ref_path_for_cm = sm_extract_last_frame_as_image(continued_video_path, segment_processing_dir, segment_task_id_str)
+ if input_images_for_cm:
+ end_ref_path_for_cm = input_images_for_cm[0]
+ else: # Subsequent segment when continuing
+ if len(input_images_for_cm) > segment_idx:
+ start_ref_path_for_cm = input_images_for_cm[segment_idx - 1]
+ end_ref_path_for_cm = input_images_for_cm[segment_idx]
+ else: # From scratch
+ if len(input_images_for_cm) > segment_idx + 1:
+ start_ref_path_for_cm = input_images_for_cm[segment_idx]
+ end_ref_path_for_cm = input_images_for_cm[segment_idx + 1]
+
+ # Download images if they are URLs so they exist locally for the color matching function.
+ if start_ref_path_for_cm:
+ start_ref_path_for_cm = sm_download_image_if_url(start_ref_path_for_cm, segment_processing_dir, segment_task_id_str)
+ if end_ref_path_for_cm:
+ end_ref_path_for_cm = sm_download_image_if_url(end_ref_path_for_cm, segment_processing_dir, segment_task_id_str)
+
+ dprint(f"Seg {segment_idx} CM Refs: Start='{start_ref_path_for_cm}', End='{end_ref_path_for_cm}'")
+ # --- End Color Match Reference Image Determination ---
+
+ # --- Prepare VACE Refs for this Segment (moved to headless) ---
+ actual_vace_image_ref_paths_for_wgp = []
+        # Get the VACE ref instructions for this segment. The orchestrator enqueues the
+        # per-segment list under "vace_image_refs_to_prepare_by_headless" in the segment payload.
+        vace_ref_instructions_from_orchestrator = segment_params.get("vace_image_refs_to_prepare_by_headless", [])
+
+ # Ensure vace_ref_instructions_from_orchestrator is a list, default to empty list if None
+ if vace_ref_instructions_from_orchestrator is None:
+ vace_ref_instructions_from_orchestrator = []
+
+ # Filter instructions for the current segment_idx
+ # The segment_idx_for_naming in the instruction should match the current segment_idx
+ relevant_vace_instructions = [
+ instr for instr in vace_ref_instructions_from_orchestrator
+ if instr.get("segment_idx_for_naming") == segment_idx
+ ]
+ dprint(f"Segment {segment_idx}: Found {len(relevant_vace_instructions)} VACE ref instructions relevant to this segment.")
+
+ if relevant_vace_instructions:
+ # Ensure parsed_res_wh is available
+ current_parsed_res_wh = full_orchestrator_payload.get("parsed_resolution_wh")
+ # If resolution is provided as string (e.g., "512x512"), convert it to tuple[int, int]
+ if isinstance(current_parsed_res_wh, str):
+ try:
+ parsed_tuple = sm_parse_resolution(current_parsed_res_wh)
+ if parsed_tuple is not None:
+ current_parsed_res_wh = parsed_tuple
+ else:
+ dprint(f"[WARNING] Segment {segment_idx}: Failed to parse resolution string '{current_parsed_res_wh}'. Proceeding with original value.")
+ except Exception as e_par:
+ dprint(f"[WARNING] Segment {segment_idx}: Error parsing resolution '{current_parsed_res_wh}': {e_par}. Proceeding with string value (may cause errors).")
+ if not current_parsed_res_wh:
+ # Fallback or error if resolution not found; for now, dprint and proceed (helper might handle None resolution)
+ dprint(f"[WARNING] Segment {segment_idx}: parsed_resolution_wh not found in full_orchestrator_payload. VACE refs might not be resized correctly.")
+
+ for ref_instr in relevant_vace_instructions:
+            # Pass segment_image_download_dir to sm_prepare_vace_ref_for_segment so remote refs are downloaded locally
+ dprint(f"Segment {segment_idx}: Preparing VACE ref from instruction: {ref_instr}")
+ processed_ref_path = sm_prepare_vace_ref_for_segment(
+ ref_instruction=ref_instr,
+ segment_processing_dir=segment_processing_dir,
+ target_resolution_wh=current_parsed_res_wh,
+ image_download_dir=segment_image_download_dir, # Pass it here
+ task_id_for_logging=segment_task_id_str
+ )
+ if processed_ref_path:
+ actual_vace_image_ref_paths_for_wgp.append(processed_ref_path)
+ dprint(f"Segment {segment_idx}: Successfully prepared VACE ref: {processed_ref_path}")
+ else:
+ dprint(f"Segment {segment_idx}: Failed to prepare VACE ref from instruction: {ref_instr}. It will be omitted.")
+ # --- End VACE Ref Preparation ---
+
+ # --- 2. Guide Video Preparation ---
+ actual_guide_video_path_for_wgp: Path | None = None
+ path_to_previous_segment_video_output_for_guide: str | None = None
+
+ is_first_segment = segment_params.get("is_first_segment", segment_idx == 0) # is_first_segment should be reliable
+ is_first_segment_from_scratch = is_first_segment and not full_orchestrator_payload.get("continue_from_video_resolved_path")
+ is_first_new_segment_after_continue = is_first_segment and full_orchestrator_payload.get("continue_from_video_resolved_path")
+ is_subsequent_segment = not is_first_segment
+
+ # Ensure parsed_res_wh is a tuple of integers with model grid snapping
+ parsed_res_wh_str = full_orchestrator_payload["parsed_resolution_wh"]
+ try:
+ parsed_res_raw = sm_parse_resolution(parsed_res_wh_str)
+ if parsed_res_raw is None:
+ raise ValueError(f"sm_parse_resolution returned None for input: {parsed_res_wh_str}")
+ parsed_res_wh = snap_resolution_to_model_grid(parsed_res_raw)
+ except Exception as e_parse_res:
+ msg = f"Seg {segment_idx}: Invalid format or error parsing parsed_resolution_wh '{parsed_res_wh_str}': {e_parse_res}"
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}"); return False, msg
+ dprint(f"Segment {segment_idx}: Parsed resolution (w,h): {parsed_res_wh}")
+
+ # --- Single Image Journey Detection ---
+ input_images_for_cm_check = full_orchestrator_payload.get("input_image_paths_resolved", [])
+ is_single_image_journey = (
+ len(input_images_for_cm_check) == 1
+ and full_orchestrator_payload.get("continue_from_video_resolved_path") is None
+ and segment_params.get("is_first_segment")
+ and segment_params.get("is_last_segment")
+ )
+ if is_single_image_journey:
+ dprint(f"Seg {segment_idx}: Detected a single-image journey. Adjusting guide and mask generation.")
+
+ # Calculate total frames for this segment once and reuse
+ base_duration = segment_params.get("segment_frames_target", full_orchestrator_payload["segment_frames_expanded"][segment_idx])
+ frame_overlap_from_previous = segment_params.get("frame_overlap_from_previous", 0)
+ # The user-facing 'segment_frames_target' should represent the total length of the segment,
+ # not just the new content. The overlap is handled internally for transition.
+ total_frames_for_segment = base_duration
+
+ print(f"[SEGMENT_DEBUG] Segment {segment_idx} (Task {segment_task_id_str}): FRAME ANALYSIS")
+ print(f"[SEGMENT_DEBUG] base_duration (segment_frames_target): {base_duration}")
+ print(f"[SEGMENT_DEBUG] frame_overlap_from_previous: {frame_overlap_from_previous}")
+ print(f"[SEGMENT_DEBUG] total_frames_for_segment: {total_frames_for_segment}")
+ print(f"[SEGMENT_DEBUG] is_first_segment: {segment_params.get('is_first_segment', False)}")
+ print(f"[SEGMENT_DEBUG] is_last_segment: {segment_params.get('is_last_segment', False)}")
+ print(f"[SEGMENT_DEBUG] use_causvid_lora: {full_orchestrator_payload.get('apply_causvid', False)}")
+
+ fps_helpers = full_orchestrator_payload.get("fps_helpers", 16)
+ fade_in_duration_str = full_orchestrator_payload["fade_in_params_json_str"]
+ fade_out_duration_str = full_orchestrator_payload["fade_out_params_json_str"]
+
+ # Define gray_frame_bgr here for use in subsequent segment strength adjustment
+ gray_frame_bgr = sm_create_color_frame(parsed_res_wh, (128, 128, 128))
+
+ # -------------------------------------------------------------
+ # Generate mask video for active/inactive frames if enabled
+ # -------------------------------------------------------------
+ mask_video_path_for_wgp: Path | None = None # default
+ if mask_active_frames:
+ try:
+ # --- Determine which frame indices should be kept (inactive = black) ---
+ inactive_indices: set[int] = set()
+
+ # Define overlap_count up front for consistent logging
+ overlap_count = max(0, int(frame_overlap_from_previous))
+
+ if is_single_image_journey:
+ # For a single image journey, only the first frame is kept from the guide.
+ inactive_indices.add(0)
+ dprint(f"Seg {segment_idx} Mask: Single image journey - keeping only frame 0 inactive.")
+ else:
+ # 1) Frames reused from the previous segment (overlap)
+ inactive_indices.update(range(overlap_count))
+
+ # 2) First frame when this is the very first segment from scratch
+ is_first_segment_val = segment_params.get("is_first_segment", False)
+ is_continue_scenario = full_orchestrator_payload.get("continue_from_video_resolved_path") is not None
+ if is_first_segment_val and not is_continue_scenario:
+ inactive_indices.add(0)
+
+ # 3) Last frame for ALL segments - each segment travels TO a target image
+ # Every segment ends at its target image, which should be kept (inactive/black)
+ inactive_indices.add(total_frames_for_segment - 1)
+
+ # --- NEW DEBUG LOGGING FOR MASK/OVERLAP DETAILS ---
+ print(f"[MASK_DEBUG] Segment {segment_idx}: frame_overlap_from_previous={frame_overlap_from_previous}")
+ print(f"[MASK_DEBUG] Segment {segment_idx}: inactive (masked) frame indices: {sorted(list(inactive_indices))}")
+ print(f"[MASK_DEBUG] Segment {segment_idx}: active (unmasked) frame indices: {[i for i in range(total_frames_for_segment) if i not in inactive_indices]}")
+ # --- END DEBUG LOGGING ---
+
+ # Debug: Show the conditions that determined inactive indices
+ dprint(
+ f"Seg {segment_idx}: Mask conditions - is_first_segment={segment_params.get('is_first_segment', False)}, "
+ f"is_continue_scenario={full_orchestrator_payload.get('continue_from_video_resolved_path') is not None}, is_last_segment={segment_params.get('is_last_segment', False)}, "
+ f"overlap_count={frame_overlap_from_previous}, is_single_image_journey={is_single_image_journey}"
+ )
+
+ mask_filename = f"{orchestrator_run_id}_seg{segment_idx:02d}_mask.mp4"
+ # Create mask video in the same directory as guide video for consistency
+ mask_out_path_tmp = segment_processing_dir / mask_filename
+
+ # Use the generalized mask creation function
+ from ..common_utils import create_mask_video_from_inactive_indices
+ created_mask_vid = create_mask_video_from_inactive_indices(
+ total_frames=total_frames_for_segment,
+ resolution_wh=parsed_res_wh,
+ inactive_frame_indices=inactive_indices,
+ output_path=mask_out_path_tmp,
+ fps=fps_helpers,
+ task_id_for_logging=segment_task_id_str,
+ dprint=dprint
+ )
+
+ if created_mask_vid:
+ mask_video_path_for_wgp = created_mask_vid
+ dprint(f"Seg {segment_idx}: mask video generated at {mask_video_path_for_wgp}")
+ else:
+ dprint(f"[WARNING] Seg {segment_idx}: Failed to generate mask video.")
+ except Exception as e_mask_gen2:
+ dprint(f"[WARNING] Seg {segment_idx}: Mask video generation error: {e_mask_gen2}")
+
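+        # Fade params arrive as JSON strings with keys low_point, high_point, curve_type
+        # and duration_factor; fall back to a neutral fade (low 0.0, high 1.0,
+        # ease_in_out, duration factor 0) if parsing fails.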
+ try: # Parsing fade params
+ fade_in_p = json.loads(fade_in_duration_str)
+ fi_low, fi_high, fi_curve, fi_factor = float(fade_in_p.get("low_point",0)), float(fade_in_p.get("high_point",1)), str(fade_in_p.get("curve_type","ease_in_out")), float(fade_in_p.get("duration_factor",0))
+ except Exception as e_fade_in:
+ fi_low, fi_high, fi_curve, fi_factor = 0.0,1.0,"ease_in_out",0.0
+ dprint(f"Seg {segment_idx} Warn: Using default fade-in params due to parse error on '{fade_in_duration_str}': {e_fade_in}")
+ try:
+ fade_out_p = json.loads(fade_out_duration_str)
+ fo_low, fo_high, fo_curve, fo_factor = float(fade_out_p.get("low_point",0)), float(fade_out_p.get("high_point",1)), str(fade_out_p.get("curve_type","ease_in_out")), float(fade_out_p.get("duration_factor",0))
+ except Exception as e_fade_out:
+ fo_low, fo_high, fo_curve, fo_factor = 0.0,1.0,"ease_in_out",0.0
+ dprint(f"Seg {segment_idx} Warn: Using default fade-out params due to parse error on '{fade_out_duration_str}': {e_fade_out}")
+
+ if is_first_new_segment_after_continue:
+ path_to_previous_segment_video_output_for_guide = full_orchestrator_payload.get("continue_from_video_resolved_path")
+ if not path_to_previous_segment_video_output_for_guide or not Path(path_to_previous_segment_video_output_for_guide).exists():
+ msg = f"Seg {segment_idx}: Continue video path {path_to_previous_segment_video_output_for_guide} invalid."
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}"); return False, msg
+ elif is_subsequent_segment:
+ # Get predecessor task ID and its output location in a single call using Edge Function (or fallback for SQLite)
+ task_dependency_id, raw_path_from_db = db_ops.get_predecessor_output_via_edge_function(segment_task_id_str)
+
+ if task_dependency_id and raw_path_from_db:
+ dprint(f"Seg {segment_idx}: Task {segment_task_id_str} depends on {task_dependency_id} with output: {raw_path_from_db}")
+ # path_to_previous_segment_video_output_for_guide will be relative ("files/...") if from SQLite and stored that way
+ # or absolute if from Supabase or stored absolutely in SQLite.
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH and raw_path_from_db.startswith("files/"):
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ path_to_previous_segment_video_output_for_guide = str((sqlite_db_parent / "public" / raw_path_from_db).resolve())
+ dprint(f"Seg {segment_idx}: Resolved SQLite relative path from DB '{raw_path_from_db}' to absolute path '{path_to_previous_segment_video_output_for_guide}'")
+ else:
+ # Path from DB is already absolute (Supabase) or an old absolute SQLite path
+ path_to_previous_segment_video_output_for_guide = raw_path_from_db
+ elif task_dependency_id and not raw_path_from_db:
+ dprint(f"Seg {segment_idx}: Found dependency task {task_dependency_id} but no output_location available.")
+ path_to_previous_segment_video_output_for_guide = None
+ else:
+ dprint(f"Seg {segment_idx}: No dependency found for task {segment_task_id_str}. Cannot create guide video based on predecessor.")
+ path_to_previous_segment_video_output_for_guide = None
+
+ # --- New: Handle Supabase public URLs by downloading them locally for guide processing ---
+ if path_to_previous_segment_video_output_for_guide and path_to_previous_segment_video_output_for_guide.startswith("http"):
+ try:
+ dprint(f"Seg {segment_idx}: Detected remote URL for previous segment: {path_to_previous_segment_video_output_for_guide}. Downloading...")
+ # Reuse download_file utility from common_utils
+                from ..common_utils import download_file as sm_download_file
+ remote_url = path_to_previous_segment_video_output_for_guide
+ local_filename = Path(remote_url).name
+ # Store under segment_processing_dir to keep things tidy
+ local_download_path = segment_processing_dir / f"prev_{segment_idx:02d}_{local_filename}"
+ # Ensure directory exists
+ segment_processing_dir.mkdir(parents=True, exist_ok=True)
+ # Perform download if file not already present
+ if not local_download_path.exists():
+ sm_download_file(remote_url, segment_processing_dir, local_download_path.name)
+ dprint(f"Seg {segment_idx}: Downloaded previous segment video to {local_download_path}")
+ else:
+ dprint(f"Seg {segment_idx}: Local copy of previous segment video already exists at {local_download_path}")
+ path_to_previous_segment_video_output_for_guide = str(local_download_path.resolve())
+ except Exception as e_dl_prev:
+ dprint(f"[WARNING] Seg {segment_idx}: Failed to download remote previous segment video: {e_dl_prev}")
+ # Leave path unchanged – will trigger the existing invalid path error below
+
+ if not path_to_previous_segment_video_output_for_guide or not Path(path_to_previous_segment_video_output_for_guide).exists():
+ error_detail_path = raw_path_from_db if 'raw_path_from_db' in locals() and raw_path_from_db else path_to_previous_segment_video_output_for_guide
+ msg = f"Seg {segment_idx}: Prev segment output for guide invalid/not found. Expected from prev task output. Path: {error_detail_path}"
+ print(f"[ERROR Task {segment_task_id_str}]: {msg}"); return False, msg
+
+ try: # Guide Video Creation Block
+ guide_video_base_name = f"{orchestrator_run_id}_seg{segment_idx:02d}_guide"
+ input_images_resolved_original = full_orchestrator_payload["input_image_paths_resolved"]
+
+ # Guide video path will be handled by sm_create_guide_video_for_travel_segment using centralized logic
+ guide_video_target_dir = segment_processing_dir
+ dprint(f"Seg {segment_idx} (Task {segment_task_id_str}): Guide video will be created in {guide_video_target_dir}")
+
+ # The download is now handled inside sm_create_guide_video_for_travel_segment (via sm_image_to_frame)
+ # Just pass the original paths.
+ input_images_resolved_for_guide = input_images_resolved_original
+
+ end_anchor_img_path_str_idx = segment_idx + 1
+ if full_orchestrator_payload.get("continue_from_video_resolved_path"):
+ end_anchor_img_path_str_idx = segment_idx
+
+ # SAFETY: If this is the last segment (i.e. there is no subsequent image),
+ # the "+1" logic above will be out of range. When that happens we
+ # signal the guide-creator that there is *no* end anchor by passing -1.
+ if end_anchor_img_path_str_idx >= len(input_images_resolved_for_guide):
+ end_anchor_img_path_str_idx = -1 # No end anchor image available.
+
+ # For a single image journey, there is no end anchor image.
+ if is_single_image_journey:
+ end_anchor_img_path_str_idx = -1 # Signal no end image to the creation function
+
+ # ------------------------------------------------------------------
+ # Show-input-images (banner) – determine start/end images now so the
+ # information can be propagated downstream to the chaining stage.
+ # ------------------------------------------------------------------
+ show_input_images_enabled = bool(full_orchestrator_payload.get("show_input_images", False))
+ start_image_for_banner = None
+ end_image_for_banner = None
+ if show_input_images_enabled:
+ try:
+ # For banner overlay, always show the first and last images of the entire journey
+ # This provides consistent context across all segments
+ if len(input_images_resolved_original) > 0:
+ start_image_for_banner = input_images_resolved_original[0] # Always first image
+
+ if len(input_images_resolved_original) > 1:
+ end_image_for_banner = input_images_resolved_original[-1] # Always last image
+ elif len(input_images_resolved_original) == 1:
+ # Single image journey - use the same image for both
+ end_image_for_banner = input_images_resolved_original[0]
+
+ # Ensure both banner images are local paths (download if URL)
+ if start_image_for_banner:
+ start_image_for_banner = sm_download_image_if_url(
+ start_image_for_banner,
+ segment_processing_dir,
+ segment_task_id_str,
+ )
+ if end_image_for_banner:
+ end_image_for_banner = sm_download_image_if_url(
+ end_image_for_banner,
+ segment_processing_dir,
+ segment_task_id_str,
+ )
+
+ except Exception as e_banner_sel:
+ dprint(f"Seg {segment_idx}: Error selecting banner images for show_input_images: {e_banner_sel}")
+ # ------------------------------------------------------------------
+
+ actual_guide_video_path_for_wgp = sm_create_guide_video_for_travel_segment(
+ segment_idx_for_logging=segment_idx,
+ end_anchor_image_index=end_anchor_img_path_str_idx,
+ is_first_segment_from_scratch=is_first_segment_from_scratch,
+ total_frames_for_segment=total_frames_for_segment,
+ parsed_res_wh=parsed_res_wh,
+ fps_helpers=fps_helpers,
+ input_images_resolved_for_guide=input_images_resolved_for_guide,
+ path_to_previous_segment_video_output_for_guide=path_to_previous_segment_video_output_for_guide,
+ output_target_dir=guide_video_target_dir,
+ guide_video_base_name=guide_video_base_name,
+ segment_image_download_dir=segment_image_download_dir,
+            task_id_for_logging=segment_task_id_str,
+ full_orchestrator_payload=full_orchestrator_payload,
+ segment_params=segment_params,
+ single_image_journey=is_single_image_journey,
+ dprint=dprint
+ )
+ except Exception as e_guide:
+ print(f"ERROR Task {segment_task_id_str} guide prep: {e_guide}")
+ traceback.print_exc()
+ actual_guide_video_path_for_wgp = None
+ # --- Invoke WGP Generation directly ---
+ if actual_guide_video_path_for_wgp is None and not is_first_segment_from_scratch:
+ # If guide creation failed AND it was essential (i.e., for any segment except the very first one from scratch)
+ msg = f"Task {segment_task_id_str}: Essential guide video failed to generate. Cannot proceed with WGP processing."
+ print(f"[ERROR] {msg}")
+ return False, msg
+
+ final_frames_for_wgp_generation = total_frames_for_segment
+ current_wgp_engine = "wgp" # Defaulting to WGP for travel segments
+
+ print(f"[WGP_DEBUG] Segment {segment_idx}: GENERATION PARAMETERS")
+ print(f"[WGP_DEBUG] final_frames_for_wgp_generation: {final_frames_for_wgp_generation}")
+ print(f"[WGP_DEBUG] parsed_res_wh: {parsed_res_wh}")
+ print(f"[WGP_DEBUG] fps_helpers: {fps_helpers}")
+ print(f"[WGP_DEBUG] model_name: {full_orchestrator_payload['model_name']}")
+ print(f"[WGP_DEBUG] use_causvid_lora: {full_orchestrator_payload.get('apply_causvid', False)}")
+
+ dprint(f"Task {segment_task_id_str}: Requesting WGP generation with {final_frames_for_wgp_generation} frames.")
+
+ if final_frames_for_wgp_generation <= 0:
+ msg = f"Task {segment_task_id_str}: Calculated WGP frames {final_frames_for_wgp_generation}. Cannot generate. Check segment_frames_target and overlap."
+ print(f"[ERROR] {msg}")
+ return False, msg
+
+ # The WGP task will run with a unique ID, but it's processed in-line now
+ wgp_inline_task_id = sm_generate_unique_task_id(f"wgp_inline_{segment_task_id_str[:8]}_")
+
+ # Define the absolute final output path for the WGP generation by process_single_task.
+ # If DB_TYPE is SQLite, process_single_task will ignore this and save to public/files, returning a relative path.
+ # If not SQLite, process_single_task will use this path (or its default construction) and return an absolute path.
+ wgp_video_filename = f"{orchestrator_run_id}_seg{segment_idx:02d}_output.mp4"
+ # For non-SQLite, wgp_final_output_path_for_this_segment is a suggestion for process_single_task
+ # For SQLite, this specific path isn't strictly used by process_single_task for its *final* save, but can be logged.
+ wgp_final_output_path_for_this_segment = segment_processing_dir / wgp_video_filename
+
+ safe_vace_image_ref_paths_for_wgp = [str(p.resolve()) if p else None for p in actual_vace_image_ref_paths_for_wgp]
+ safe_vace_image_ref_paths_for_wgp = [p for p in safe_vace_image_ref_paths_for_wgp if p is not None]
+
+ # If no image refs, pass None instead of empty list to avoid WGP VAE encoder issues
+ if not safe_vace_image_ref_paths_for_wgp:
+ safe_vace_image_ref_paths_for_wgp = None
+
+ current_segment_base_prompt = segment_params.get("base_prompt", " ")
+ prompt_for_wgp = ensure_valid_prompt(current_segment_base_prompt)
+ negative_prompt_for_wgp = ensure_valid_negative_prompt(segment_params.get("negative_prompt", " "))
+
+ dprint(f"Seg {segment_idx} (Task {segment_task_id_str}): Effective prompt for WGP: '{prompt_for_wgp}'")
+
+ # Compute video_prompt_type for wgp: always include 'V' when a guide video is provided; add 'M' if a mask video is also attached.
+ # Additionally, include 'I' when reference images are supplied so that VACE models
+ # properly process the `image_refs` list. Passing image refs without the 'I'
+ # flag causes Wan2GP to attempt to pre-process the paths as PIL images and
+        # raises AttributeError ("'str' object has no attribute 'size'").
+ video_prompt_type_str = (
+ "V" +
+ ("M" if mask_video_path_for_wgp else "") +
+ ("I" if safe_vace_image_ref_paths_for_wgp else "")
+ )
+
+ wgp_payload = {
+ "task_id": wgp_inline_task_id, # ID for this specific WGP generation operation
+ "model": full_orchestrator_payload["model_name"],
+ "prompt": prompt_for_wgp, # Use the processed prompt_for_wgp
+            "negative_prompt": negative_prompt_for_wgp,
+ "resolution": f"{parsed_res_wh[0]}x{parsed_res_wh[1]}", # Use parsed tuple here
+ "frames": final_frames_for_wgp_generation,
+ "seed": segment_params["seed_to_use"],
+ # output_path for process_single_task:
+ # - If SQLite, it's ignored, output goes to public/files, and a relative path is returned.
+ # - If not SQLite, this suggested path (or process_single_task's default) is used, and an absolute path is returned.
+ "output_path": str(wgp_final_output_path_for_this_segment.resolve()),
+ "video_guide_path": str(actual_guide_video_path_for_wgp.resolve()) if actual_guide_video_path_for_wgp and actual_guide_video_path_for_wgp.exists() else None,
+ "use_causvid_lora": full_orchestrator_payload.get("apply_causvid", False),
+ "apply_reward_lora": full_orchestrator_payload.get("apply_reward_lora", False),
+ "cfg_star_switch": full_orchestrator_payload.get("cfg_star_switch", 0),
+ "cfg_zero_step": full_orchestrator_payload.get("cfg_zero_step", -1),
+ "image_refs_paths": safe_vace_image_ref_paths_for_wgp,
+ # Propagate video_prompt_type so VACE model correctly interprets guide and mask inputs
+ "video_prompt_type": video_prompt_type_str,
+ # Attach mask video if available
+ **({"video_mask": str(mask_video_path_for_wgp.resolve())} if mask_video_path_for_wgp else {}),
+ }
+ if additional_loras:
+ wgp_payload["additional_loras"] = additional_loras
+
+ if full_orchestrator_payload.get("params_json_str_override"):
+ try:
+ additional_p = json.loads(full_orchestrator_payload["params_json_str_override"])
+ # Ensure override cannot change key params that indirectly control output length or resolution
+ additional_p.pop("frames", None); additional_p.pop("video_length", None)
+ additional_p.pop("steps", None); additional_p.pop("num_inference_steps", None)
+ additional_p.pop("resolution", None); additional_p.pop("output_path", None)
+ wgp_payload.update(additional_p)
+ except Exception as e_json: dprint(f"Error merging override params for WGP payload: {e_json}")
+
+ # Add travel_chain_details so process_single_task can call _handle_travel_chaining_after_wgp
+ wgp_payload["travel_chain_details"] = {
+ "orchestrator_task_id_ref": orchestrator_task_id_ref,
+ "orchestrator_run_id": orchestrator_run_id,
+ "segment_index_completed": segment_idx,
+ "is_last_segment_in_sequence": segment_params["is_last_segment"],
+ "current_run_base_output_dir": str(current_run_base_output_dir.resolve()),
+ "full_orchestrator_payload": full_orchestrator_payload,
+ "segment_processing_dir_for_saturation": str(segment_processing_dir.resolve()),
+ "is_first_new_segment_after_continue": is_first_new_segment_after_continue,
+ "is_subsequent_segment": is_subsequent_segment,
+ "colour_match_videos": effective_colour_match_enabled,
+ "cm_start_ref_path": start_ref_path_for_cm,
+ "cm_end_ref_path": end_ref_path_for_cm,
+ "show_input_images": show_input_images_enabled,
+ "start_image_path": start_image_for_banner,
+ "end_image_path": end_image_for_banner,
+ }
+
+ dprint(f"Seg {segment_idx} (Task {segment_task_id_str}): Invoking WGP generation via centralized wrapper (task_id for WGP op: {wgp_inline_task_id})")
+
+ # Process additional LoRAs using shared function
+ processed_additional_loras = {}
+ if additional_loras:
+ dprint(f"Seg {segment_idx}: Processing additional LoRAs using shared function")
+ model_filename_for_task = wgp_mod.get_model_filename(
+ full_orchestrator_payload["model_name"],
+ wgp_mod.transformer_quantization,
+ wgp_mod.transformer_dtype_policy,
+ )
+ processed_additional_loras = process_additional_loras_shared(
+ additional_loras,
+ wgp_mod,
+ model_filename_for_task,
+ segment_task_id_str,
+ dprint
+ )
+
+ # ------------------------------------------------------------------
+ # Ensure sensible defaults for critical generation params
+ # ------------------------------------------------------------------
+ lighti2x_enabled = bool(segment_params.get("use_lighti2x_lora", False) or full_orchestrator_payload.get("use_lighti2x_lora", False))
+
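+        # Resolve inference steps with precedence: segment params ("num_inference_steps" or
+        # "steps"), then the orchestrator payload, then any JSON-override values already
+        # merged into wgp_payload, and finally a default of 5 (LightI2X) or 30.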
+ num_inference_steps = (
+ segment_params.get("num_inference_steps")
+ or segment_params.get("steps") # Check for "steps" as alternative
+ or full_orchestrator_payload.get("num_inference_steps")
+ or full_orchestrator_payload.get("steps") # Check for "steps" as alternative
+ or wgp_payload.get("num_inference_steps") # Check wgp_payload after JSON override
+ or wgp_payload.get("steps") # Check for "steps" in wgp_payload
+ or (5 if lighti2x_enabled else 30)
+ )
+
+ if lighti2x_enabled:
+ guidance_scale_default = 1.0
+ flow_shift_default = 5.0
+ else:
+ guidance_scale_default = full_orchestrator_payload.get("guidance_scale", 5.0)
+ flow_shift_default = full_orchestrator_payload.get("flow_shift", 3.0)
+
+ # Use the centralized WGP wrapper instead of process_single_task
+ generation_success, wgp_output_path_or_msg = generate_single_video(
+ wgp_mod=wgp_mod,
+ task_id=wgp_inline_task_id,
+ prompt=prompt_for_wgp,
+ negative_prompt=negative_prompt_for_wgp,
+ resolution=f"{parsed_res_wh[0]}x{parsed_res_wh[1]}",
+ video_length=final_frames_for_wgp_generation,
+ seed=segment_params["seed_to_use"],
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale_default,
+ flow_shift=flow_shift_default,
+ # Resolve the actual model .safetensors file that WGP expects – the
+ # orchestrator payload only contains the shorthand model alias
+ # (e.g. "vace_14B").
+ model_filename=wgp_mod.get_model_filename(
+ full_orchestrator_payload["model_name"],
+ wgp_mod.transformer_quantization,
+ wgp_mod.transformer_dtype_policy,
+ ),
+ video_guide=str(actual_guide_video_path_for_wgp.resolve()) if actual_guide_video_path_for_wgp and actual_guide_video_path_for_wgp.exists() else None,
+ video_mask=str(mask_video_path_for_wgp.resolve()) if mask_video_path_for_wgp else None,
+ image_refs=safe_vace_image_ref_paths_for_wgp,
+ use_causvid_lora=full_orchestrator_payload.get("apply_causvid", False) or full_orchestrator_payload.get("use_causvid_lora", False),
+ use_lighti2x_lora=full_orchestrator_payload.get("use_lighti2x_lora", False) or segment_params.get("use_lighti2x_lora", False),
+ apply_reward_lora=effective_apply_reward_lora,
+ additional_loras=processed_additional_loras,
+ video_prompt_type=video_prompt_type_str,
+ dprint=dprint,
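+            # Forward any remaining wgp_payload entries (e.g. values merged from
+            # params_json_str_override) as extra kwargs, skipping keys already passed
+            # explicitly above.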
+ **({k: v for k, v in wgp_payload.items() if k not in [
+ 'task_id', 'prompt', 'negative_prompt', 'resolution', 'frames', 'seed',
+ 'model', 'model_filename', 'video_guide', 'video_mask', 'image_refs',
+ 'use_causvid_lora', 'apply_reward_lora', 'additional_loras', 'video_prompt_type',
+ 'use_lighti2x_lora',
+ 'num_inference_steps', 'guidance_scale', 'flow_shift'
+ ]})
+ )
+
+ print(f"[WGP_DEBUG] Segment {segment_idx}: GENERATION RESULT")
+ print(f"[WGP_DEBUG] generation_success: {generation_success}")
+ print(f"[WGP_DEBUG] wgp_output_path_or_msg: {wgp_output_path_or_msg}")
+
+ # Analyze the WGP output if successful
+ if generation_success and wgp_output_path_or_msg:
+ wgp_debug_info = debug_video_analysis(wgp_output_path_or_msg, f"WGP_RAW_OUTPUT_Seg{segment_idx}", segment_task_id_str)
+ print(f"[WGP_DEBUG] Expected frames: {final_frames_for_wgp_generation}")
+ print(f"[WGP_DEBUG] Actual frames: {wgp_debug_info.get('frame_count', 'ERROR')}")
+ if wgp_debug_info.get('frame_count') != final_frames_for_wgp_generation:
+ print(f"[WGP_DEBUG] ⚠️ FRAME COUNT MISMATCH! Expected {final_frames_for_wgp_generation}, got {wgp_debug_info.get('frame_count')}")
+
+ if generation_success:
+ # Apply post-processing chain (saturation, brightness, color matching)
+ chain_success, chain_message, final_chained_path = _handle_travel_chaining_after_wgp(
+ wgp_task_params=wgp_payload,
+ actual_wgp_output_video_path=wgp_output_path_or_msg,
+ wgp_mod=wgp_mod,
+ image_download_dir=segment_image_download_dir,
+ dprint=dprint
+ )
+
+ if chain_success and final_chained_path:
+ final_segment_video_output_path_str = final_chained_path
+ output_message_for_segment_task = f"Segment {segment_idx} processing (WGP generation & chaining) completed. Final output: {final_segment_video_output_path_str}"
+
+ # Analyze final chained output
+ final_debug_info = debug_video_analysis(final_chained_path, f"FINAL_CHAINED_Seg{segment_idx}", segment_task_id_str)
+ print(f"[CHAIN_DEBUG] Segment {segment_idx}: FINAL CHAINED OUTPUT ANALYSIS")
+ print(f"[CHAIN_DEBUG] Expected frames: {final_frames_for_wgp_generation}")
+ print(f"[CHAIN_DEBUG] Final frames: {final_debug_info.get('frame_count', 'ERROR')}")
+ if final_debug_info.get('frame_count') != final_frames_for_wgp_generation:
+ print(f"[CHAIN_DEBUG] ⚠️ CHAINING CHANGED FRAME COUNT! Expected {final_frames_for_wgp_generation}, got {final_debug_info.get('frame_count')}")
+ else:
+ # Use raw WGP output if chaining failed
+ final_segment_video_output_path_str = wgp_output_path_or_msg
+ output_message_for_segment_task = f"Segment {segment_idx} WGP completed but chaining failed: {chain_message}. Using raw output: {final_segment_video_output_path_str}"
+ print(f"[WARNING] {output_message_for_segment_task}")
+
+ # Analyze raw output being used as final
+ if wgp_output_path_or_msg:
+ raw_debug_info = debug_video_analysis(wgp_output_path_or_msg, f"RAW_AS_FINAL_Seg{segment_idx}", segment_task_id_str)
+
+ print(f"Seg {segment_idx} (Task {segment_task_id_str}): {output_message_for_segment_task}")
+ else:
+ # wgp_output_path_or_msg contains the error message if generation_success is False
+ final_segment_video_output_path_str = None
+ output_message_for_segment_task = f"Segment {segment_idx} (Task {segment_task_id_str}) processing (WGP generation) failed. Error: {wgp_output_path_or_msg}"
+ print(f"[ERROR] {output_message_for_segment_task}")
+
+ # Notify orchestrator of segment failure
+ try:
+ db_ops.update_task_status(
+ orchestrator_task_id_ref,
+ db_ops.STATUS_FAILED,
+ output_message_for_segment_task[:500] # Truncate to avoid DB overflow
+ )
+ dprint(f"Segment {segment_idx}: Marked orchestrator task {orchestrator_task_id_ref} as FAILED due to WGP generation failure")
+ except Exception as e_orch:
+ dprint(f"Segment {segment_idx}: Warning - could not update orchestrator status: {e_orch}")
+
+        # No polling is needed here: generate_single_video runs synchronously above.
+
+        # On success, final_segment_video_output_path_str becomes the output_location
+        # recorded for this travel_segment task.
+ return generation_success, final_segment_video_output_path_str if generation_success else output_message_for_segment_task
+
+ except Exception as e:
+ print(f"ERROR Task {segment_task_id_str}: Unexpected error during segment processing: {e}")
+ traceback.print_exc()
+
+ # Notify orchestrator of segment failure
+ if 'orchestrator_task_id_ref' in locals() and orchestrator_task_id_ref:
+ try:
+ error_msg = f"Segment {segment_idx if 'segment_idx' in locals() else 'unknown'} failed: {str(e)[:200]}"
+ db_ops.update_task_status(
+ orchestrator_task_id_ref,
+ db_ops.STATUS_FAILED,
+ error_msg
+ )
+ dprint(f"Segment: Marked orchestrator task {orchestrator_task_id_ref} as FAILED due to exception")
+ except Exception as e_orch:
+ dprint(f"Segment: Warning - could not update orchestrator status: {e_orch}")
+
+ return False, f"Segment {segment_idx if 'segment_idx' in locals() else 'unknown'} failed: {str(e)[:200]}"
+
+# --- SM_RESTRUCTURE: New function to handle chaining after WGP/Comfy sub-task ---
+def _handle_travel_chaining_after_wgp(wgp_task_params: dict, actual_wgp_output_video_path: str | None, wgp_mod, image_download_dir: Path | str | None = None, *, dprint) -> tuple[bool, str, str | None]:
+ """
+ Handles the chaining logic after a WGP sub-task for a travel segment completes.
+ This includes post-generation saturation and enqueuing the next segment or stitch task.
+ Returns: (success_bool, message_str, final_video_path_for_db_str_or_none)
+ The third element is the path that should be considered the definitive output of the WGP task
+ (e.g., path to saturated video if saturation was applied).
+ """
+ chain_details = wgp_task_params.get("travel_chain_details")
+ wgp_task_id = wgp_task_params.get("task_id", "unknown_wgp_task")
+
+ if not chain_details:
+ return False, f"Task {wgp_task_id}: Missing travel_chain_details. Cannot proceed with chaining.", None
+
+ # actual_wgp_output_video_path comes from process_single_task.
+ # If DB_TYPE is sqlite, it will be like "files/wgp_output.mp4".
+ # Otherwise, it's an absolute path.
+ if not actual_wgp_output_video_path: # Check if it's None or empty string
+ return False, f"Task {wgp_task_id}: WGP output video path is None or empty. Cannot chain.", None
+
+ # This variable will track the absolute path of the video as it gets processed.
+ video_to_process_abs_path: Path
+ # This will hold the path to be stored in the DB (can be relative for SQLite).
+ final_video_path_for_db = actual_wgp_output_video_path
+
+ # Resolve initial absolute path
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH and isinstance(actual_wgp_output_video_path, str) and actual_wgp_output_video_path.startswith("files/"):
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ video_to_process_abs_path = sqlite_db_parent / "public" / actual_wgp_output_video_path
+ else:
+ video_to_process_abs_path = Path(actual_wgp_output_video_path)
+
+ if not video_to_process_abs_path.exists():
+ return False, f"Task {wgp_task_id}: Source video for chaining '{video_to_process_abs_path}' (from '{actual_wgp_output_video_path}') does not exist.", actual_wgp_output_video_path
+
+ try:
+ orchestrator_task_id_ref = chain_details["orchestrator_task_id_ref"]
+ orchestrator_run_id = chain_details["orchestrator_run_id"]
+ segment_idx_completed = chain_details["segment_index_completed"]
+ full_orchestrator_payload = chain_details["full_orchestrator_payload"]
+ segment_processing_dir_for_saturation_str = chain_details["segment_processing_dir_for_saturation"]
+
+ is_first_new_segment_after_continue = chain_details.get("is_first_new_segment_after_continue", False)
+ is_subsequent_segment_val = chain_details.get("is_subsequent_segment", False)
+
+ dprint(f"Chaining for WGP task {wgp_task_id} (segment {segment_idx_completed} of run {orchestrator_run_id}). Initial video: {video_to_process_abs_path}")
+
+ # --- Always move WGP output to proper location first ---
+ # For SQLite, this moves the file from outputs/ to public/files/
+ # For other DBs, this ensures consistent file management
+ moved_filename = f"{orchestrator_run_id}_seg{segment_idx_completed:02d}_output{video_to_process_abs_path.suffix}"
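+        # prepare_output_path returns the absolute on-disk target plus the path string to
+        # record in the DB (typically a relative "files/..." path for SQLite, absolute otherwise).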
+ moved_video_abs_path, moved_db_path = prepare_output_path(
+ task_id=wgp_task_id,
+ filename=moved_filename,
+ main_output_dir_base=Path(segment_processing_dir_for_saturation_str)
+ )
+
+ # Copy the WGP output to the proper location
+ try:
+ # Ensure encoder has finished writing the source file
+ sm_wait_for_file_stable(video_to_process_abs_path, checks=3, interval=1.0, dprint=dprint)
+
+ shutil.copy2(video_to_process_abs_path, moved_video_abs_path)
+ print(f"[CHAIN_DEBUG] Moved WGP output from {video_to_process_abs_path} to {moved_video_abs_path}")
+ debug_video_analysis(moved_video_abs_path, f"MOVED_WGP_OUTPUT_Seg{segment_idx_completed}", wgp_task_id)
+ dprint(f"Chain (Seg {segment_idx_completed}): Moved WGP output from {video_to_process_abs_path} to {moved_video_abs_path}")
+
+            # Remember the original absolute path before repointing to the moved copy, so
+            # cleanup removes the correct file even when the DB stores a relative path.
+            original_wgp_output_abs_path = video_to_process_abs_path
+
+            # Update paths for further processing
+            video_to_process_abs_path = moved_video_abs_path
+            final_video_path_for_db = moved_db_path
+
+            # Clean up original WGP output if not in debug mode
+            if not full_orchestrator_payload.get("skip_cleanup_enabled", False) and \
+               not full_orchestrator_payload.get("debug_mode_enabled", False):
+                try:
+                    original_wgp_output_abs_path.unlink()
+                    dprint(f"Chain (Seg {segment_idx_completed}): Cleaned up original WGP output {original_wgp_output_abs_path}")
+ except Exception as e_cleanup:
+ dprint(f"Chain (Seg {segment_idx_completed}): Warning - could not clean up original WGP output: {e_cleanup}")
+
+ except Exception as e_move:
+ dprint(f"Chain (Seg {segment_idx_completed}): Warning - could not move WGP output to proper location: {e_move}. Using original path.")
+
+ # --- Post-generation Processing Chain ---
+ # Saturation and Brightness are only applied to segments AFTER the first one.
+ if is_subsequent_segment_val or is_first_new_segment_after_continue:
+
+ # --- 1. Saturation ---
+ sat_level = full_orchestrator_payload.get("after_first_post_generation_saturation")
+ if sat_level is not None and isinstance(sat_level, (float, int)) and sat_level >= 0.0 and abs(sat_level - 1.0) > 1e-6:
+ dprint(f"Chain (Seg {segment_idx_completed}): Applying post-gen saturation {sat_level} to {video_to_process_abs_path}")
+
+ sat_filename = f"s{segment_idx_completed}_sat_{sat_level:.2f}{video_to_process_abs_path.suffix}"
+ saturated_video_output_abs_path, new_db_path = prepare_output_path(
+ task_id=wgp_task_id,
+ filename=sat_filename,
+ main_output_dir_base=Path(segment_processing_dir_for_saturation_str)
+ )
+
+ if sm_apply_saturation_to_video_ffmpeg(str(video_to_process_abs_path), saturated_video_output_abs_path, sat_level):
+ print(f"[CHAIN_DEBUG] Saturation applied successfully to segment {segment_idx_completed}")
+ debug_video_analysis(saturated_video_output_abs_path, f"SATURATED_Seg{segment_idx_completed}", wgp_task_id)
+ dprint(f"Chain (Seg {segment_idx_completed}): Saturation successful. New path: {new_db_path}")
+ _cleanup_intermediate_video(full_orchestrator_payload, video_to_process_abs_path, segment_idx_completed, "raw", dprint)
+
+ video_to_process_abs_path = saturated_video_output_abs_path
+ final_video_path_for_db = new_db_path
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Saturation failed for segment {segment_idx_completed}")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): Saturation failed. Continuing with unsaturated video.")
+
+ # --- 2. Brightness ---
+ brightness_adjust = full_orchestrator_payload.get("after_first_post_generation_brightness", 0.0)
+ if isinstance(brightness_adjust, (float, int)) and abs(brightness_adjust) > 1e-6:
+ dprint(f"Chain (Seg {segment_idx_completed}): Applying post-gen brightness {brightness_adjust} to {video_to_process_abs_path}")
+
+ bright_filename = f"s{segment_idx_completed}_bright_{brightness_adjust:+.2f}{video_to_process_abs_path.suffix}"
+ brightened_video_output_abs_path, new_db_path = prepare_output_path(
+ task_id=wgp_task_id,
+ filename=bright_filename,
+ main_output_dir_base=Path(segment_processing_dir_for_saturation_str)
+ )
+
+ processed_video = apply_brightness_to_video_frames(str(video_to_process_abs_path), brightened_video_output_abs_path, brightness_adjust, wgp_task_id)
+
+ if processed_video and processed_video.exists():
+ print(f"[CHAIN_DEBUG] Brightness adjustment applied successfully to segment {segment_idx_completed}")
+ debug_video_analysis(brightened_video_output_abs_path, f"BRIGHTENED_Seg{segment_idx_completed}", wgp_task_id)
+ dprint(f"Chain (Seg {segment_idx_completed}): Brightness adjustment successful. New path: {new_db_path}")
+ _cleanup_intermediate_video(full_orchestrator_payload, video_to_process_abs_path, segment_idx_completed, "saturated", dprint)
+
+ video_to_process_abs_path = brightened_video_output_abs_path
+ final_video_path_for_db = new_db_path
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Brightness adjustment failed for segment {segment_idx_completed}")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): Brightness adjustment failed. Continuing with previous video version.")
+
+ # --- 3. Color Matching (Applied to all segments if enabled) ---
+ if chain_details.get("colour_match_videos"):
+ start_ref = chain_details.get("cm_start_ref_path")
+ end_ref = chain_details.get("cm_end_ref_path")
+ print(f"[CHAIN_DEBUG] Color matching requested for segment {segment_idx_completed}")
+ print(f"[CHAIN_DEBUG] Start ref: {start_ref}")
+ print(f"[CHAIN_DEBUG] End ref: {end_ref}")
+ dprint(f"Chain (Seg {segment_idx_completed}): Color matching requested. Start Ref: {start_ref}, End Ref: {end_ref}")
+
+ if start_ref and end_ref and Path(start_ref).exists() and Path(end_ref).exists():
+ cm_filename = f"s{segment_idx_completed}_colormatched{video_to_process_abs_path.suffix}"
+ cm_video_output_abs_path, new_db_path = prepare_output_path(
+ task_id=wgp_task_id,
+ filename=cm_filename,
+ main_output_dir_base=Path(segment_processing_dir_for_saturation_str)
+ )
+
+ matched_video_path = sm_apply_color_matching_to_video(
+ str(video_to_process_abs_path),
+ start_ref,
+ end_ref,
+ str(cm_video_output_abs_path),
+ dprint
+ )
+
+ if matched_video_path and Path(matched_video_path).exists():
+ print(f"[CHAIN_DEBUG] Color matching applied successfully to segment {segment_idx_completed}")
+ debug_video_analysis(Path(matched_video_path), f"COLORMATCHED_Seg{segment_idx_completed}", wgp_task_id)
+ dprint(f"Chain (Seg {segment_idx_completed}): Color matching successful. New path: {new_db_path}")
+ _cleanup_intermediate_video(full_orchestrator_payload, video_to_process_abs_path, segment_idx_completed, "pre-colormatch", dprint)
+
+ video_to_process_abs_path = Path(matched_video_path)
+ final_video_path_for_db = new_db_path
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Color matching failed for segment {segment_idx_completed}")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): Color matching failed. Continuing with previous video version.")
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Color matching skipped - missing or invalid reference images")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): Skipping color matching due to missing or invalid reference image paths.")
+
+ # --- 4. Optional: Overlay start/end images above the video ---
+ if chain_details.get("show_input_images"):
+ banner_start = chain_details.get("start_image_path")
+ banner_end = chain_details.get("end_image_path")
+ if banner_start and banner_end and Path(banner_start).exists() and Path(banner_end).exists():
+ banner_filename = f"s{segment_idx_completed}_with_inputs{video_to_process_abs_path.suffix}"
+ banner_video_abs_path, new_db_path = prepare_output_path(
+ task_id=wgp_task_id,
+ filename=banner_filename,
+ main_output_dir_base=Path(segment_processing_dir_for_saturation_str)
+ )
+
+ if sm_overlay_start_end_images_above_video(
+ start_image_path=banner_start,
+ end_image_path=banner_end,
+ input_video_path=str(video_to_process_abs_path),
+ output_video_path=str(banner_video_abs_path),
+ dprint=dprint,
+ ):
+ print(f"[CHAIN_DEBUG] Banner overlay applied successfully to segment {segment_idx_completed}")
+ debug_video_analysis(banner_video_abs_path, f"BANNER_OVERLAY_Seg{segment_idx_completed}", wgp_task_id)
+ dprint(f"Chain (Seg {segment_idx_completed}): Banner overlay successful. New path: {new_db_path}")
+ _cleanup_intermediate_video(full_orchestrator_payload, video_to_process_abs_path, segment_idx_completed, "pre-banner", dprint)
+
+ video_to_process_abs_path = banner_video_abs_path
+ final_video_path_for_db = new_db_path
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Banner overlay failed for segment {segment_idx_completed}")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): Banner overlay failed. Keeping previous video version.")
+ else:
+ print(f"[CHAIN_DEBUG] WARNING: Banner overlay skipped - missing valid start/end images")
+ dprint(f"[WARNING] Chain (Seg {segment_idx_completed}): show_input_images enabled but valid start/end images not found.")
+
+ # The orchestrator has already enqueued all segment and stitch tasks.
+ print(f"[CHAIN_DEBUG] Chaining complete for segment {segment_idx_completed}")
+ print(f"[CHAIN_DEBUG] Final video path for DB: {final_video_path_for_db}")
+ debug_video_analysis(video_to_process_abs_path, f"FINAL_CHAINED_Seg{segment_idx_completed}", wgp_task_id)
+ msg = f"Chain (Seg {segment_idx_completed}): Post-WGP processing complete. Final path for this WGP task's output: {final_video_path_for_db}"
+ dprint(msg)
+ return True, msg, str(final_video_path_for_db)
+
+ except Exception as e_chain:
+ error_msg = f"Chain (Seg {chain_details.get('segment_index_completed', 'N/A')} for WGP {wgp_task_id}): Failed during chaining: {e_chain}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+
+ # Notify orchestrator of chaining failure
+ orchestrator_task_id_ref = chain_details.get("orchestrator_task_id_ref") if chain_details else None
+ if orchestrator_task_id_ref:
+ try:
+ db_ops.update_task_status(
+ orchestrator_task_id_ref,
+ db_ops.STATUS_FAILED,
+ error_msg[:500] # Truncate to avoid DB overflow
+ )
+ dprint(f"Chain: Marked orchestrator task {orchestrator_task_id_ref} as FAILED due to chaining failure")
+ except Exception as e_orch:
+ dprint(f"Chain: Warning - could not update orchestrator status: {e_orch}")
+
+ return False, error_msg, str(final_video_path_for_db) # Return path as it was before error
+
+
+
+def _cleanup_intermediate_video(orchestrator_payload, video_path: Path, segment_idx: int, stage: str, dprint):
+ """Helper to cleanup intermediate video files during chaining."""
+ # Delete intermediates **only** when every cleanup-bypass flag is false.
+ # That now includes the headless-server global debug flag (db_ops.debug_mode)
+ # so that running the server with --debug automatically preserves files.
+ if (
+ not orchestrator_payload.get("skip_cleanup_enabled", False)
+ and not orchestrator_payload.get("debug_mode_enabled", False)
+ and not db_ops.debug_mode
+ and video_path.exists()
+ ):
+ try:
+ video_path.unlink()
+ dprint(f"Chain (Seg {segment_idx}): Removed intermediate '{stage}' video {video_path}")
+ except Exception as e_del:
+ dprint(f"Chain (Seg {segment_idx}): Warning - could not remove intermediate video {video_path}: {e_del}")
+
+def _handle_travel_stitch_task(task_params_from_db: dict, main_output_dir_base: Path, stitch_task_id_str: str, *, dprint):
+ print(f"[IMMEDIATE DEBUG] _handle_travel_stitch_task: Starting for {stitch_task_id_str}")
+ print(f"[IMMEDIATE DEBUG] task_params_from_db keys: {list(task_params_from_db.keys())}")
+ print(f"[IMMEDIATE DEBUG] DB_TYPE: {db_ops.DB_TYPE}")
+
+ dprint(f"_handle_travel_stitch_task: Starting for {stitch_task_id_str}")
+ dprint(f"Stitch task_params_from_db (first 1000 chars): {json.dumps(task_params_from_db, default=str, indent=2)[:1000]}...")
+ stitch_params = task_params_from_db # This now contains full_orchestrator_payload
+ stitch_success = False
+ final_video_location_for_db = None
+
+ try:
+ # --- 1. Initialization & Parameter Extraction ---
+ orchestrator_task_id_ref = stitch_params.get("orchestrator_task_id_ref")
+ orchestrator_run_id = stitch_params.get("orchestrator_run_id")
+ full_orchestrator_payload = stitch_params.get("full_orchestrator_payload")
+
+ print(f"[IMMEDIATE DEBUG] orchestrator_run_id: {orchestrator_run_id}")
+ print(f"[IMMEDIATE DEBUG] orchestrator_task_id_ref: {orchestrator_task_id_ref}")
+ print(f"[IMMEDIATE DEBUG] full_orchestrator_payload present: {full_orchestrator_payload is not None}")
+
+ if not all([orchestrator_task_id_ref, orchestrator_run_id, full_orchestrator_payload]):
+ msg = f"Stitch task {stitch_task_id_str} missing critical orchestrator refs or full_orchestrator_payload."
+ print(f"[ERROR Task {stitch_task_id_str}]: {msg}")
+ return False, msg
+
+ project_id_for_stitch = stitch_params.get("project_id")
+ current_run_base_output_dir_str = stitch_params.get("current_run_base_output_dir",
+ full_orchestrator_payload.get("main_output_dir_for_run", str(main_output_dir_base.resolve())))
+ current_run_base_output_dir = Path(current_run_base_output_dir_str)
+
+ # Use the base directory directly without creating stitch-specific subdirectories
+ stitch_processing_dir = current_run_base_output_dir
+ stitch_processing_dir.mkdir(parents=True, exist_ok=True)
+ dprint(f"Stitch Task {stitch_task_id_str}: Processing in {stitch_processing_dir.resolve()}")
+
+ num_expected_new_segments = full_orchestrator_payload["num_new_segments_to_generate"]
+ print(f"[IMMEDIATE DEBUG] num_expected_new_segments: {num_expected_new_segments}")
+
+ # Ensure parsed_res_wh is a tuple of integers for stitch task with model grid snapping
+ parsed_res_wh_str = full_orchestrator_payload["parsed_resolution_wh"]
+ try:
+ parsed_res_raw = sm_parse_resolution(parsed_res_wh_str)
+ if parsed_res_raw is None:
+ raise ValueError(f"sm_parse_resolution returned None for input: {parsed_res_wh_str}")
+ parsed_res_wh = snap_resolution_to_model_grid(parsed_res_raw)
+ except Exception as e_parse_res_stitch:
+ msg = f"Stitch Task {stitch_task_id_str}: Invalid format or error parsing parsed_resolution_wh '{parsed_res_wh_str}': {e_parse_res_stitch}"
+            print(f"[ERROR Task {stitch_task_id_str}]: {msg}")
+            return False, msg
+ dprint(f"Stitch Task {stitch_task_id_str}: Parsed resolution (w,h): {parsed_res_wh}")
+
+ final_fps = full_orchestrator_payload.get("fps_helpers", 16)
+ expanded_frame_overlaps = full_orchestrator_payload["frame_overlap_expanded"]
+ crossfade_sharp_amt = full_orchestrator_payload.get("crossfade_sharp_amt", 0.3)
+ initial_continued_video_path_str = full_orchestrator_payload.get("continue_from_video_resolved_path")
+
+ # [OVERLAP DEBUG] Add detailed debug for overlap values
+ print(f"[OVERLAP DEBUG] Stitch: expanded_frame_overlaps from payload: {expanded_frame_overlaps}")
+ dprint(f"[OVERLAP DEBUG] Stitch: expanded_frame_overlaps from payload: {expanded_frame_overlaps}")
+
+ # Extract upscale parameters
+ upscale_factor = full_orchestrator_payload.get("upscale_factor", 0.0) # Default to 0.0 if not present
+ upscale_model_name = full_orchestrator_payload.get("upscale_model_name") # Default to None if not present
+
+ # --- 2. Collect Paths to All Segment Videos ---
+ segment_video_paths_for_stitch = []
+ if initial_continued_video_path_str and Path(initial_continued_video_path_str).exists():
+ dprint(f"Stitch: Prepending initial continued video: {initial_continued_video_path_str}")
+ # Check the continue video properties
+ cap = cv2.VideoCapture(str(initial_continued_video_path_str))
+ if cap.isOpened():
+ continue_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ continue_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ continue_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ cap.release()
+ dprint(f"Stitch: Continue video properties - Resolution: {continue_width}x{continue_height}, Frames: {continue_frame_count}")
+ dprint(f"Stitch: Target resolution for stitching: {parsed_res_wh[0]}x{parsed_res_wh[1]}")
+ if continue_width != parsed_res_wh[0] or continue_height != parsed_res_wh[1]:
+ dprint(f"Stitch: WARNING - Continue video resolution mismatch! Will need resizing during crossfade.")
+ else:
+ dprint(f"Stitch: ERROR - Could not open continue video for property check")
+ segment_video_paths_for_stitch.append(str(Path(initial_continued_video_path_str).resolve()))
+
+ # Fetch completed segments with a small retry loop to handle race conditions
+        max_stitch_fetch_retries = 6  # Up to 6 attempts with a 3s pause between them (~15s of waiting)
+ completed_segment_outputs_from_db = []
+
+ print(f"[IMMEDIATE DEBUG] About to start retry loop for run_id: {orchestrator_run_id}")
+
+ for attempt in range(max_stitch_fetch_retries):
+ print(f"[IMMEDIATE DEBUG] Stitch fetch attempt {attempt+1}/{max_stitch_fetch_retries} for run_id: {orchestrator_run_id}")
+ dprint(f"[DEBUG] Stitch fetch attempt {attempt+1}/{max_stitch_fetch_retries} for run_id: {orchestrator_run_id}")
+
+ try:
+ completed_segment_outputs_from_db = db_ops.get_completed_segment_outputs_for_stitch(orchestrator_run_id, project_id=project_id_for_stitch) or []
+ print(f"[IMMEDIATE DEBUG] DB query returned: {completed_segment_outputs_from_db}")
+ except Exception as e_db_query:
+ print(f"[IMMEDIATE DEBUG] DB query failed: {e_db_query}")
+ completed_segment_outputs_from_db = []
+
+ dprint(f"[DEBUG] Attempt {attempt+1} returned {len(completed_segment_outputs_from_db)} segments")
+ print(f"[IMMEDIATE DEBUG] Attempt {attempt+1} returned {len(completed_segment_outputs_from_db)} segments")
+
+ if len(completed_segment_outputs_from_db) >= num_expected_new_segments:
+ dprint(f"[DEBUG] Expected {num_expected_new_segments} segment rows found on attempt {attempt+1}. Proceeding.")
+ print(f"[IMMEDIATE DEBUG] Expected {num_expected_new_segments} segment rows found on attempt {attempt+1}. Proceeding.")
+ break
+            dprint(f"Stitch: Not enough completed segment rows yet (attempt {attempt+1}/{max_stitch_fetch_retries}). Waiting 3s and retrying...")
+ print(f"[IMMEDIATE DEBUG] Insufficient segments found (attempt {attempt+1}/{max_stitch_fetch_retries}). Waiting 3s and retrying...")
+ if attempt < max_stitch_fetch_retries - 1: # Don't sleep after the last attempt
+ time.sleep(3)
+ dprint(f"Stitch Task {stitch_task_id_str}: Completed segments fetched: {completed_segment_outputs_from_db}")
+ print(f"[IMMEDIATE DEBUG] Final completed_segment_outputs_from_db: {completed_segment_outputs_from_db}")
+
+ # ------------------------------------------------------------------
+ # 2b. Resolve each returned video path (local, SQLite-relative, or URL)
+ # ------------------------------------------------------------------
+ print(f"[STITCH_DEBUG] Starting path resolution for {len(completed_segment_outputs_from_db)} segments")
+ print(f"[STITCH_DEBUG] Raw DB results: {completed_segment_outputs_from_db}")
+ dprint(f"[DEBUG] Starting path resolution for {len(completed_segment_outputs_from_db)} segments")
+ for seg_idx, video_path_str_from_db in completed_segment_outputs_from_db:
+ print(f"[STITCH_DEBUG] Processing segment {seg_idx} with path: {video_path_str_from_db}")
+ dprint(f"[DEBUG] Processing segment {seg_idx} with path: {video_path_str_from_db}")
+ resolved_video_path_for_stitch: Path | None = None
+
+ if not video_path_str_from_db:
+ print(f"[STITCH_DEBUG] WARNING: Segment {seg_idx} has empty video_path in DB; skipping.")
+ dprint(f"[WARNING] Stitch: Segment {seg_idx} has empty video_path in DB; skipping.")
+ continue
+
+ # Case A: Relative path that starts with files/ (works for both sqlite and supabase when headless has local access)
+ if video_path_str_from_db.startswith("files/") or video_path_str_from_db.startswith("public/files/"):
+ print(f"[STITCH_DEBUG] Case A: Relative path detected for segment {seg_idx}")
+ sqlite_db_parent = None
+ if db_ops.SQLITE_DB_PATH:
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ else:
+ # Fall back: examine cwd and assume standard layout ../public
+ try:
+ sqlite_db_parent = Path.cwd()
+ except Exception:
+ sqlite_db_parent = Path(".")
+                # Strip an optional "public/" prefix so the path is joined under <db dir>/public exactly once
+                relative_part = video_path_str_from_db.removeprefix("public/")
+                absolute_path_candidate = (sqlite_db_parent / "public" / relative_part).resolve()
+ print(f"[STITCH_DEBUG] Resolved relative path '{video_path_str_from_db}' to '{absolute_path_candidate}' for segment {seg_idx}")
+ dprint(f"Stitch: Resolved relative path '{video_path_str_from_db}' to '{absolute_path_candidate}' for segment {seg_idx}")
+ if absolute_path_candidate.exists() and absolute_path_candidate.is_file():
+ resolved_video_path_for_stitch = absolute_path_candidate
+ print(f"[STITCH_DEBUG] ✅ File exists at resolved path")
+ else:
+ print(f"[STITCH_DEBUG] ❌ File missing at resolved path")
+ dprint(f"[WARNING] Stitch: Resolved absolute path '{absolute_path_candidate}' for segment {seg_idx} is missing.")
+
+ # Case B: Remote public URL (Supabase storage)
+ elif video_path_str_from_db.startswith("http"):
+ print(f"[STITCH_DEBUG] Case B: Remote URL detected for segment {seg_idx}")
+ try:
+ from ..common_utils import download_file as sm_download_file
+ remote_url = video_path_str_from_db
+ local_filename = Path(remote_url).name
+ local_download_path = stitch_processing_dir / f"seg{seg_idx:02d}_{local_filename}"
+ print(f"[STITCH_DEBUG] Remote URL: {remote_url}")
+ print(f"[STITCH_DEBUG] Local download path: {local_download_path}")
+ dprint(f"[DEBUG] Remote URL detected, local download path: {local_download_path}")
+
+ # Check if cached file exists and validate its frame count against orchestrator's expected values
+ need_download = True
+ if local_download_path.exists():
+ print(f"[STITCH_DEBUG] Local copy exists, validating frame count...")
+ try:
+ cached_frames, _ = sm_get_video_frame_count_and_fps(str(local_download_path))
+ expected_segment_frames = full_orchestrator_payload["segment_frames_expanded"]
+ expected_frames = expected_segment_frames[seg_idx] if seg_idx < len(expected_segment_frames) else None
+ print(f"[STITCH_DEBUG] Cached file has {cached_frames} frames (expected: {expected_frames})")
+
+ if expected_frames and cached_frames == expected_frames:
+ print(f"[STITCH_DEBUG] ✅ Cached file frame count matches expected ({cached_frames} frames)")
+ need_download = False
+ elif expected_frames:
+ print(f"[STITCH_DEBUG] ❌ Cached file frame count mismatch! Expected {expected_frames}, got {cached_frames}, will re-download")
+ else:
+ print(f"[STITCH_DEBUG] ❌ No expected frame count available for segment {seg_idx}, will re-download")
+ except Exception as e_validate:
+ print(f"[STITCH_DEBUG] ❌ Could not validate cached file: {e_validate}, will re-download")
+
+ if need_download:
+ print(f"[STITCH_DEBUG] Downloading remote segment {seg_idx}...")
+ dprint(f"Stitch: Downloading remote segment {seg_idx} from {remote_url} to {local_download_path}")
+ # Remove stale cached file if it exists
+ if local_download_path.exists():
+ local_download_path.unlink()
+ sm_download_file(remote_url, stitch_processing_dir, local_download_path.name)
+ print(f"[STITCH_DEBUG] ✅ Download completed for segment {seg_idx}")
+ dprint(f"[DEBUG] Download completed for segment {seg_idx}")
+ else:
+ print(f"[STITCH_DEBUG] ✅ Using validated cached file for segment {seg_idx}")
+ dprint(f"Stitch: Using validated cached file for segment {seg_idx} at {local_download_path}")
+
+ resolved_video_path_for_stitch = local_download_path
+ except Exception as e_dl:
+ print(f"[STITCH_DEBUG] ❌ Download failed for segment {seg_idx}: {e_dl}")
+ dprint(f"[WARNING] Stitch: Failed to download remote video for segment {seg_idx}: {e_dl}")
+
+ # Case C: Provided absolute/local path
+ else:
+ print(f"[STITCH_DEBUG] Case C: Absolute/local path for segment {seg_idx}")
+ absolute_path_candidate = Path(video_path_str_from_db).resolve()
+ print(f"[STITCH_DEBUG] Treating as absolute path: {absolute_path_candidate}")
+ dprint(f"[DEBUG] Treating as absolute path: {absolute_path_candidate}")
+ if absolute_path_candidate.exists() and absolute_path_candidate.is_file():
+ resolved_video_path_for_stitch = absolute_path_candidate
+ print(f"[STITCH_DEBUG] ✅ Absolute path exists")
+ dprint(f"[DEBUG] Absolute path exists: {absolute_path_candidate}")
+ else:
+ print(f"[STITCH_DEBUG] ❌ Absolute path missing or not a file")
+ dprint(f"[WARNING] Stitch: Absolute path '{absolute_path_candidate}' for segment {seg_idx} does not exist or is not a file.")
+
+ if resolved_video_path_for_stitch is not None:
+ segment_video_paths_for_stitch.append(str(resolved_video_path_for_stitch))
+ print(f"[STITCH_DEBUG] ✅ Added video for segment {seg_idx}: {resolved_video_path_for_stitch}")
+ dprint(f"Stitch: Added video for segment {seg_idx}: {resolved_video_path_for_stitch}")
+
+ # Analyze the resolved video immediately
+ debug_video_analysis(resolved_video_path_for_stitch, f"RESOLVED_Seg{seg_idx}", stitch_task_id_str)
+ else:
+ print(f"[STITCH_DEBUG] ❌ Unable to resolve video for segment {seg_idx}; will be excluded from stitching.")
+ dprint(f"[WARNING] Stitch: Unable to resolve video for segment {seg_idx}; will be excluded from stitching.")
+
+ print(f"[STITCH_DEBUG] Path resolution complete")
+ print(f"[STITCH_DEBUG] Final segment_video_paths_for_stitch: {segment_video_paths_for_stitch}")
+ print(f"[STITCH_DEBUG] Total videos collected: {len(segment_video_paths_for_stitch)}")
+ dprint(f"[DEBUG] Final segment_video_paths_for_stitch: {segment_video_paths_for_stitch}")
+ dprint(f"[DEBUG] Total videos collected: {len(segment_video_paths_for_stitch)}")
+ # [CRITICAL DEBUG] Log each video's frame count before stitching
+ print(f"[CRITICAL DEBUG] About to stitch videos:")
+ expected_segment_frames = full_orchestrator_payload["segment_frames_expanded"]
+ for idx, video_path in enumerate(segment_video_paths_for_stitch):
+ try:
+ frame_count, fps = sm_get_video_frame_count_and_fps(video_path)
+ expected_frames = expected_segment_frames[idx] if idx < len(expected_segment_frames) else "unknown"
+ print(f"[CRITICAL DEBUG] Video {idx}: {video_path} -> {frame_count} frames @ {fps} FPS (expected: {expected_frames})")
+ if expected_frames != "unknown" and frame_count != expected_frames:
+ print(f"[CRITICAL DEBUG] ⚠️ FRAME COUNT MISMATCH! Expected {expected_frames}, got {frame_count}")
+ except Exception as e_debug:
+ print(f"[CRITICAL DEBUG] Video {idx}: {video_path} -> ERROR: {e_debug}")
+
+ total_videos_for_stitch = (1 if initial_continued_video_path_str and Path(initial_continued_video_path_str).exists() else 0) + num_expected_new_segments
+ dprint(f"[DEBUG] Expected total videos: {total_videos_for_stitch}")
+ if len(segment_video_paths_for_stitch) < total_videos_for_stitch:
+ # This is a warning because some segments might have legitimately failed and been skipped by their handlers.
+ # The stitcher should proceed with what it has, unless it has zero or one video when multiple were expected.
+ dprint(f"[WARNING] Stitch: Expected {total_videos_for_stitch} videos for stitch, but found {len(segment_video_paths_for_stitch)}. Stitching with available videos.")
+
+ if not segment_video_paths_for_stitch:
+ dprint(f"[ERROR] Stitch: No valid segment videos found to stitch. DB returned {len(completed_segment_outputs_from_db)} segments, but none resolved to valid paths.")
+ raise ValueError("Stitch: No valid segment videos found to stitch.")
+ if len(segment_video_paths_for_stitch) == 1 and total_videos_for_stitch > 1:
+ dprint(f"Stitch: Only one video segment found ({segment_video_paths_for_stitch[0]}) but {total_videos_for_stitch} were expected. Using this single video as the 'stitched' output.")
+ # No actual stitching needed, just move/copy this single video to final dest.
+
+ # --- 3. Stitching (Crossfade or Concatenate) ---
+ current_stitched_video_path: Path | None = None # This will hold the path to the current version of the stitched video
+
+
+ if len(segment_video_paths_for_stitch) == 1:
+ # If only one video, copy it directly using prepare_output_path
+ source_single_video_path = Path(segment_video_paths_for_stitch[0])
+ single_video_filename = f"{orchestrator_run_id}_final{source_single_video_path.suffix}"
+
+ current_stitched_video_path, _ = prepare_output_path(
+ task_id=stitch_task_id_str,
+ filename=single_video_filename,
+ main_output_dir_base=stitch_processing_dir
+ )
+ shutil.copy2(str(source_single_video_path), str(current_stitched_video_path))
+ dprint(f"Stitch: Only one video found. Copied {source_single_video_path} to {current_stitched_video_path}")
+ else: # More than one video, proceed with stitching logic
+ num_stitch_points = len(segment_video_paths_for_stitch) - 1
+            # One overlap value per join, regardless of whether a continued video was prepended.
+            actual_overlaps_for_stitching = expanded_frame_overlaps[:num_stitch_points]
+
+ # --- NEW OVERLAP DEBUG LOGGING ---
+ print(f"[OVERLAP DEBUG] Number of videos: {len(segment_video_paths_for_stitch)} (expected stitch points: {num_stitch_points})")
+ print(f"[OVERLAP DEBUG] actual_overlaps_for_stitching: {actual_overlaps_for_stitching}")
+ if len(actual_overlaps_for_stitching) != num_stitch_points:
+ print(f"[OVERLAP DEBUG] ⚠️ MISMATCH! We have {len(actual_overlaps_for_stitching)} overlaps for {num_stitch_points} joins")
+ for join_idx, ov in enumerate(actual_overlaps_for_stitching):
+ print(f"[OVERLAP DEBUG] Join {join_idx} (video {join_idx} -> {join_idx+1}): overlap={ov}")
+ # --- END NEW LOGGING ---
+
+ any_positive_overlap = any(o > 0 for o in actual_overlaps_for_stitching)
+
+ raw_stitched_video_filename = f"{orchestrator_run_id}_stitched.mp4"
+ path_for_raw_stitched_video, _ = prepare_output_path(
+ task_id=stitch_task_id_str,
+ filename=raw_stitched_video_filename,
+ main_output_dir_base=stitch_processing_dir
+ )
+
+ if any_positive_overlap:
+ print(f"[CRITICAL DEBUG] Using cross-fade due to overlap values: {actual_overlaps_for_stitching}. Output to: {path_for_raw_stitched_video}")
+ print(f"[STITCH_ANALYSIS] Cross-fade stitching analysis:")
+ print(f"[STITCH_ANALYSIS] Number of videos: {len(segment_video_paths_for_stitch)}")
+ print(f"[STITCH_ANALYSIS] Overlap values: {actual_overlaps_for_stitching}")
+ print(f"[STITCH_ANALYSIS] Expected stitch points: {num_stitch_points}")
+
+ dprint(f"Stitch: Using cross-fade due to overlap values: {actual_overlaps_for_stitching}. Output to: {path_for_raw_stitched_video}")
+ all_segment_frames_lists = [sm_extract_frames_from_video(p, dprint_func=dprint) for p in segment_video_paths_for_stitch]
+
+ # [CRITICAL DEBUG] Log frame extraction results
+ print(f"[CRITICAL DEBUG] Frame extraction results:")
+ for idx, frame_list in enumerate(all_segment_frames_lists):
+ if frame_list is not None:
+ print(f"[CRITICAL DEBUG] Segment {idx}: {len(frame_list)} frames extracted")
+ else:
+ print(f"[CRITICAL DEBUG] Segment {idx}: FAILED to extract frames")
+
+ if not all(f_list is not None and len(f_list)>0 for f_list in all_segment_frames_lists):
+ raise ValueError("Stitch: Frame extraction failed for one or more segments during cross-fade prep.")
+
+ final_stitched_frames = []
+
+ # Process each stitch point
+ for i in range(num_stitch_points):
+ frames_prev_segment = all_segment_frames_lists[i]
+ frames_curr_segment = all_segment_frames_lists[i+1]
+ current_overlap_val = actual_overlaps_for_stitching[i]
+
+ print(f"[CRITICAL DEBUG] Stitch point {i}: segments {i}->{i+1}, overlap={current_overlap_val}")
+ print(f"[CRITICAL DEBUG] Prev segment: {len(frames_prev_segment)} frames, Curr segment: {len(frames_curr_segment)} frames")
+
+ # --- NEW OVERLAP DETAIL LOG ---
+ if current_overlap_val > 0:
+ start_prev = len(frames_prev_segment) - current_overlap_val
+ end_prev = len(frames_prev_segment) - 1
+ start_curr = 0
+ end_curr = current_overlap_val - 1
+ print(
+ f"[OVERLAP_DETAIL] Join {i}: blending prev[{start_prev}:{end_prev}] with curr[{start_curr}:{end_curr}] (total {current_overlap_val} frames)"
+ )
+ # --- END OVERLAP DETAIL LOG ---
+
+                    if i == 0:
+                        # For the first stitch point, add segment 0's frames up to the overlap region.
+                        # Later segments need nothing here: their cross-faded overlap and post-overlap
+                        # tail are appended below.
+                        if current_overlap_val > 0:
+                            # Add frames before the overlap region
+                            frames_before_overlap = frames_prev_segment[:-current_overlap_val]
+                            final_stitched_frames.extend(frames_before_overlap)
+                            print(f"[CRITICAL DEBUG] Added {len(frames_before_overlap)} frames from segment 0 (before overlap)")
+                        else:
+                            # No overlap, add all frames from segment 0
+                            final_stitched_frames.extend(frames_prev_segment)
+                            print(f"[CRITICAL DEBUG] Added all {len(frames_prev_segment)} frames from segment 0 (no overlap)")
+
+ if current_overlap_val > 0:
+ # Remove the overlap frames already appended from the previous segment so that
+ # they can be replaced by the blended cross-fade frames for this stitch point.
+ if i > 0:
+ frames_to_remove = min(current_overlap_val, len(final_stitched_frames))
+ if frames_to_remove > 0:
+ del final_stitched_frames[-frames_to_remove:]
+ print(f"[CRITICAL DEBUG] Removed {frames_to_remove} duplicate overlap frames before cross-fade (stitch point {i})")
+ # Blend the overlapping frames
+ faded_frames = sm_cross_fade_overlap_frames(frames_prev_segment, frames_curr_segment, current_overlap_val, "linear_sharp", crossfade_sharp_amt)
+ final_stitched_frames.extend(faded_frames)
+ print(f"[CRITICAL DEBUG] Added {len(faded_frames)} cross-faded frames")
+
+ # Add the non-overlapping part of the current segment
+ start_index_for_curr_tail = current_overlap_val
+ if len(frames_curr_segment) > start_index_for_curr_tail:
+ frames_to_add = frames_curr_segment[start_index_for_curr_tail:]
+ final_stitched_frames.extend(frames_to_add)
+ print(f"[CRITICAL DEBUG] Added {len(frames_to_add)} frames from segment {i+1} (after overlap)")
+
+ print(f"[CRITICAL DEBUG] Running total after stitch point {i}: {len(final_stitched_frames)} frames")
+
+                if not final_stitched_frames:
+                    raise ValueError("Stitch: No frames produced after cross-fade logic.")
+
+ # [CRITICAL DEBUG] Final calculation summary
+ # With proper cross-fade: output = sum(all frames) - sum(overlaps)
+ # Because overlapped frames are blended, not duplicated
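+                # Illustrative arithmetic (hypothetical numbers): three 81-frame segments with
+                # overlaps [8, 8] should stitch to 3*81 - 16 = 227 frames.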
+                total_input_frames = sum(len(frames) for frames in all_segment_frames_lists)
+                total_overlaps = sum(actual_overlaps_for_stitching)
+                expected_output_frames = total_input_frames - total_overlaps
+                actual_output_frames = len(final_stitched_frames)
+ print(f"[CRITICAL DEBUG] FINAL CROSS-FADE SUMMARY:")
+ print(f"[CRITICAL DEBUG] Total input frames: {total_input_frames}")
+ print(f"[CRITICAL DEBUG] Total overlaps: {total_overlaps}")
+ print(f"[CRITICAL DEBUG] Expected output: {expected_output_frames}")
+ print(f"[CRITICAL DEBUG] Actual output: {actual_output_frames}")
+ print(f"[CRITICAL DEBUG] Match: {expected_output_frames == actual_output_frames}")
+
+ created_video_path_obj = sm_create_video_from_frames_list(final_stitched_frames, path_for_raw_stitched_video, final_fps, parsed_res_wh)
+ if created_video_path_obj and created_video_path_obj.exists():
+ current_stitched_video_path = created_video_path_obj
+ else:
+ raise RuntimeError(f"Stitch: Cross-fade sm_create_video_from_frames_list failed to produce video at {path_for_raw_stitched_video}")
+
+ else:
+ dprint(f"Stitch: Using simple FFmpeg concatenation. Output to: {path_for_raw_stitched_video}")
+ try:
+ from .common_utils import stitch_videos_ffmpeg as sm_stitch_videos_ffmpeg
+ except ImportError:
+ print(f"[CRITICAL ERROR Task ID: {stitch_task_id_str}] Failed to import 'stitch_videos_ffmpeg'. Cannot proceed with stitching.")
+ raise
+
+ if sm_stitch_videos_ffmpeg(segment_video_paths_for_stitch, str(path_for_raw_stitched_video)):
+ current_stitched_video_path = path_for_raw_stitched_video
+ else:
+ raise RuntimeError(f"Stitch: Simple FFmpeg concatenation failed for output {path_for_raw_stitched_video}.")
+
+ if not current_stitched_video_path or not current_stitched_video_path.exists():
+ raise RuntimeError(f"Stitch: Stitching process failed, output video not found at {current_stitched_video_path}")
+
+ video_path_after_optional_upscale = current_stitched_video_path
+
+ if isinstance(upscale_factor, (float, int)) and upscale_factor > 1.0 and upscale_model_name:
+ print(f"[STITCH UPSCALE] Starting upscale process: {upscale_factor}x using model {upscale_model_name}")
+ dprint(f"Stitch: Upscaling (x{upscale_factor}) video {current_stitched_video_path.name} using model {upscale_model_name}")
+
+ original_frames_count, original_fps = sm_get_video_frame_count_and_fps(str(current_stitched_video_path))
+ if original_frames_count is None or original_frames_count == 0:
+ raise ValueError(f"Stitch: Cannot get frame count or 0 frames for video {current_stitched_video_path} before upscaling.")
+
+ print(f"[STITCH UPSCALE] Input video: {original_frames_count} frames @ {original_fps} FPS")
+ print(f"[STITCH UPSCALE] Target resolution: {int(parsed_res_wh[0] * upscale_factor)}x{int(parsed_res_wh[1] * upscale_factor)}")
+ dprint(f"[DEBUG] Pre-upscale analysis: {original_frames_count} frames, {original_fps} FPS")
+
+ target_width_upscaled = int(parsed_res_wh[0] * upscale_factor)
+ target_height_upscaled = int(parsed_res_wh[1] * upscale_factor)
+
+ upscale_sub_task_id = sm_generate_unique_task_id(f"upscale_stitch_{orchestrator_run_id}_")
+
+ upscale_payload = {
+ "task_id": upscale_sub_task_id,
+ "project_id": stitch_params.get("project_id"),
+ "model": upscale_model_name,
+ "video_source_path": str(current_stitched_video_path.resolve()),
+ "resolution": f"{target_width_upscaled}x{target_height_upscaled}",
+ "frames": original_frames_count,
+ "prompt": full_orchestrator_payload.get("original_task_args",{}).get("upscale_prompt", "cinematic, masterpiece, high detail, 4k"),
+ "seed": full_orchestrator_payload.get("seed_for_upscale", full_orchestrator_payload.get("seed_base", 12345) + 5000),
+ }
+
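+            # The upscale runs as a separately queued task whose handler is selected by
+            # upscaler_engine_to_use; this stitch task then blocks on poll_task_status
+            # until the sub-task completes or times out.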
+ db_path_for_upscale_add = db_ops.SQLITE_DB_PATH if db_ops.DB_TYPE == "sqlite" else None
+ upscaler_engine_to_use = stitch_params.get("execution_engine_for_upscale", "wgp")
+
+ db_ops.add_task_to_db(
+ task_payload=upscale_payload,
+ task_type_str=upscaler_engine_to_use
+ )
+ print(f"[STITCH UPSCALE] Enqueued upscale sub-task {upscale_sub_task_id} ({upscaler_engine_to_use}). Waiting...")
+ print(f"Stitch Task {stitch_task_id_str}: Enqueued upscale sub-task {upscale_sub_task_id} ({upscaler_engine_to_use}). Waiting...")
+
+ poll_interval_ups = full_orchestrator_payload.get("poll_interval", 15)
+ poll_timeout_ups = full_orchestrator_payload.get("poll_timeout_upscale", full_orchestrator_payload.get("poll_timeout", 30 * 60) * 2)
+
+ print(f"[STITCH UPSCALE] Polling for completion (timeout: {poll_timeout_ups}s, interval: {poll_interval_ups}s)")
+
+ upscaled_video_db_location = db_ops.poll_task_status(
+ task_id=upscale_sub_task_id,
+ poll_interval_seconds=poll_interval_ups,
+ timeout_seconds=poll_timeout_ups
+ )
+ print(f"[STITCH UPSCALE] Poll result: {upscaled_video_db_location}")
+ dprint(f"Stitch Task {stitch_task_id_str}: Upscale sub-task {upscale_sub_task_id} poll result: {upscaled_video_db_location}")
+
+ if upscaled_video_db_location:
+ upscaled_video_abs_path: Path
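+                # SQLite stores output locations as paths relative to the <db dir>/public folder;
+                # any other location is used as given and checked for existence below.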
+ if db_ops.DB_TYPE == "sqlite" and db_ops.SQLITE_DB_PATH and upscaled_video_db_location.startswith("files/"):
+ sqlite_db_parent = Path(db_ops.SQLITE_DB_PATH).resolve().parent
+ upscaled_video_abs_path = sqlite_db_parent / "public" / upscaled_video_db_location
+ else:
+ upscaled_video_abs_path = Path(upscaled_video_db_location)
+
+ if upscaled_video_abs_path.exists():
+ print(f"[STITCH UPSCALE] Upscale completed successfully: {upscaled_video_abs_path}")
+ dprint(f"Stitch: Upscale sub-task {upscale_sub_task_id} completed. Output: {upscaled_video_abs_path}")
+
+ # Analyze upscaled result
+ try:
+ upscaled_frame_count, upscaled_fps = sm_get_video_frame_count_and_fps(str(upscaled_video_abs_path))
+ print(f"[STITCH UPSCALE] Upscaled result: {upscaled_frame_count} frames @ {upscaled_fps} FPS")
+ dprint(f"[DEBUG] Post-upscale analysis: {upscaled_frame_count} frames, {upscaled_fps} FPS")
+
+ # Compare frame counts
+ if upscaled_frame_count != original_frames_count:
+ print(f"[STITCH UPSCALE] Frame count changed during upscale: {original_frames_count} → {upscaled_frame_count}")
+ except Exception as e_post_upscale:
+ print(f"[WARNING] Could not analyze upscaled video: {e_post_upscale}")
+
+ video_path_after_optional_upscale = upscaled_video_abs_path
+
+ if not full_orchestrator_payload.get("skip_cleanup_enabled", False) and \
+ not full_orchestrator_payload.get("debug_mode_enabled", False) and \
+ current_stitched_video_path.exists() and current_stitched_video_path != video_path_after_optional_upscale:
+ try:
+ current_stitched_video_path.unlink()
+ dprint(f"Stitch: Removed non-upscaled video {current_stitched_video_path} after successful upscale.")
+ except Exception as e_del_non_upscaled:
+ dprint(f"Stitch: Warning - could not remove non-upscaled video {current_stitched_video_path}: {e_del_non_upscaled}")
+ else:
+ print(f"[STITCH UPSCALE] ERROR: Upscale output missing at {upscaled_video_abs_path}. Using non-upscaled video.")
+ print(f"[WARNING] Stitch Task {stitch_task_id_str}: Upscale sub-task {upscale_sub_task_id} output missing ({upscaled_video_abs_path}). Using non-upscaled video.")
+ else:
+ print(f"[STITCH UPSCALE] ERROR: Upscale sub-task failed or timed out. Using non-upscaled video.")
+ print(f"[WARNING] Stitch Task {stitch_task_id_str}: Upscale sub-task {upscale_sub_task_id} failed or timed out. Using non-upscaled video.")
+
+ elif upscale_factor > 1.0 and not upscale_model_name:
+ print(f"[STITCH UPSCALE] Upscale factor {upscale_factor} > 1.0 but no upscale_model_name provided. Skipping upscale.")
+ dprint(f"Stitch: Upscale factor {upscale_factor} > 1.0 but no upscale_model_name provided. Skipping upscale.")
+ else:
+ print(f"[STITCH UPSCALE] No upscaling requested (factor: {upscale_factor})")
+ dprint(f"Stitch: No upscaling (factor: {upscale_factor})")
+
+ # Use prepare_output_path_with_upload to handle final video location consistently (with Supabase upload support)
+ final_video_filename = f"{orchestrator_run_id}_final{video_path_after_optional_upscale.suffix}"
+ if upscale_factor > 1.0:
+ final_video_filename = f"{orchestrator_run_id}_final_upscaled_{upscale_factor:.1f}x{video_path_after_optional_upscale.suffix}"
+
+ final_video_path, initial_db_location = prepare_output_path_with_upload(
+ task_id=stitch_task_id_str,
+ filename=final_video_filename,
+ main_output_dir_base=stitch_processing_dir,
+ dprint=dprint
+ )
+
+ # Move the video to final location if it's not already there
+ if video_path_after_optional_upscale.resolve() != final_video_path.resolve():
+ dprint(f"Stitch Task {stitch_task_id_str}: Moving {video_path_after_optional_upscale} to {final_video_path}")
+ shutil.move(str(video_path_after_optional_upscale), str(final_video_path))
+ else:
+ dprint(f"Stitch Task {stitch_task_id_str}: Video already at final destination {final_video_path}")
+
+ # Handle Supabase upload (if configured) and get final location for DB
+ final_video_location_for_db = upload_and_get_final_output_location(
+ final_video_path,
+ final_video_filename, # Pass only the filename to avoid redundant subfolder
+ initial_db_location,
+ dprint=dprint
+ )
+
+ print(f"Stitch Task {stitch_task_id_str}: Final video saved to: {final_video_path} (DB location: {final_video_location_for_db})")
+
+ # Analyze final result
+ try:
+ final_frame_count, final_fps = sm_get_video_frame_count_and_fps(str(final_video_path))
+ final_duration = final_frame_count / final_fps if final_fps > 0 else 0
+ print(f"[STITCH FINAL] Final video: {final_frame_count} frames @ {final_fps} FPS = {final_duration:.2f}s")
+ print(f"[STITCH_FINAL_ANALYSIS] Complete stitching analysis:")
+ print(f"[STITCH_FINAL_ANALYSIS] Input segments: {len(segment_video_paths_for_stitch)}")
+ print(f"[STITCH_FINAL_ANALYSIS] Overlap settings: {expanded_frame_overlaps}")
+ print(f"[STITCH_FINAL_ANALYSIS] Expected final frames: {expected_final_length if 'expected_final_length' in locals() else 'Not calculated'}")
+ print(f"[STITCH_FINAL_ANALYSIS] Actual final frames: {final_frame_count}")
+ if 'expected_final_length' in locals() and final_frame_count != expected_final_length:
+ print(f"[STITCH_FINAL_ANALYSIS] ⚠️ FINAL LENGTH MISMATCH! Expected {expected_final_length}, got {final_frame_count}")
+
+ # Detailed analysis of the final video
+ debug_video_analysis(final_video_path, "FINAL_STITCHED_VIDEO", stitch_task_id_str)
+
+ dprint(f"[DEBUG] Final video analysis: {final_frame_count} frames, {final_fps} FPS, {final_duration:.2f}s duration")
+ except Exception as e_final_analysis:
+ print(f"[WARNING] Could not analyze final video: {e_final_analysis}")
+
+ # Note: Individual segments already have banner overlays applied when show_input_images is enabled,
+ # so the stitched video will automatically include them. No additional overlay needed here.
+
+ stitch_success = True
+
+ # Note: The orchestrator will be marked as complete by the Edge Function
+ # when it processes the stitch task upload. This ensures atomic completion
+ # with the final video upload.
+ print(f"[ORCHESTRATOR_COMPLETION_DEBUG] Stitch task complete. Orchestrator {orchestrator_task_id_ref} will be marked complete by Edge Function.")
+ dprint(f"Stitch: Task complete. Orchestrator completion will be handled by Edge Function.")
+
+ # Return the final video path so the stitch task itself gets uploaded via Edge Function
+ return stitch_success, str(final_video_path.resolve())
+
+ except Exception as e:
+ print(f"[ERROR] Stitch task {stitch_task_id_str}: Unexpected error during stitching: {e}")
+ traceback.print_exc()
+
+ # Notify orchestrator of stitch failure
+ if 'orchestrator_task_id_ref' in locals() and orchestrator_task_id_ref:
+ try:
+ error_msg = f"Stitch task failed: {str(e)[:200]}"
+ db_ops.update_task_status(
+ orchestrator_task_id_ref,
+ db_ops.STATUS_FAILED,
+ error_msg
+ )
+ dprint(f"Stitch: Marked orchestrator task {orchestrator_task_id_ref} as FAILED due to exception")
+ except Exception as e_orch:
+ dprint(f"Stitch: Warning - could not update orchestrator status: {e_orch}")
+
+ return False, f"Stitch task failed: {str(e)[:200]}"
diff --git a/source/specialized_handlers.py b/source/specialized_handlers.py
new file mode 100644
index 000000000..77d7a3474
--- /dev/null
+++ b/source/specialized_handlers.py
@@ -0,0 +1,423 @@
+"""Specialized task handlers for headless.py."""
+
+import traceback
+import tempfile
+import shutil
+from pathlib import Path
+import numpy as np
+from PIL import Image
+import cv2
+
+# Add the parent directory to Python path to allow Wan2GP module import
+import sys
+wan2gp_path = Path(__file__).resolve().parent.parent / "Wan2GP"
+if str(wan2gp_path) not in sys.path:
+ sys.path.insert(0, str(wan2gp_path))
+
+try:
+ from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
+except ImportError:
+ PoseBodyFaceVideoAnnotator = None
+
+try:
+ from preprocessing.midas.depth import DepthAnnotator
+except ImportError:
+ DepthAnnotator = None
+
+from . import db_operations as db_ops
+from .common_utils import (
+    sm_get_unique_target_path,
+    parse_resolution as sm_parse_resolution,
+    prepare_output_path,
+    save_frame_from_video,
+    report_orchestrator_failure,
+    prepare_output_path_with_upload,
+    upload_and_get_final_output_location,
+)
+from .video_utils import rife_interpolate_images_to_video as sm_rife_interpolate_images_to_video
+
+def handle_generate_openpose_task(task_params_dict: dict, main_output_dir_base: Path, task_id: str, dprint: callable):
+ """Handles the 'generate_openpose' task."""
+ print(f"[Task ID: {task_id}] Handling 'generate_openpose' task.")
+ input_image_path_str = task_params_dict.get("input_image_path")
+ input_image_task_id = task_params_dict.get("input_image_task_id")
+ custom_output_dir = task_params_dict.get("output_dir")
+
+ if PoseBodyFaceVideoAnnotator is None:
+ msg = "PoseBodyFaceVideoAnnotator not imported. Cannot process 'generate_openpose' task."
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, "PoseBodyFaceVideoAnnotator module not available."
+
+ # If direct path is not given, try to resolve it from a dependency task ID
+ if not input_image_path_str:
+ dprint(f"Task {task_id}: 'input_image_path' not found, trying 'input_image_task_id': {input_image_task_id}")
+ if not input_image_task_id:
+ msg = "Task requires either 'input_image_path' or 'input_image_task_id'."
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ try:
+ path_from_db = db_ops.get_task_output_location_from_db(input_image_task_id)
+ if not path_from_db:
+ msg = f"Task {task_id}: Could not find output location for dependency task {input_image_task_id}."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ abs_path = db_ops.get_abs_path_from_db_path(path_from_db, dprint)
+ if not abs_path:
+ msg = f"Task {task_id}: Could not resolve or find image file from DB path '{path_from_db}'."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ input_image_path_str = str(abs_path)
+ dprint(f"Task {task_id}: Resolved input image path from task ID to '{input_image_path_str}'")
+ except Exception as e:
+ error_msg = f"Task {task_id}: Failed during input image path resolution from task ID: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ report_orchestrator_failure(task_params_dict, error_msg, dprint)
+ return False, str(e)
+
+ input_image_path = Path(input_image_path_str)
+
+ if not input_image_path.is_file():
+ print(f"[ERROR Task ID: {task_id}] Input image file not found: {input_image_path}")
+ msg = f"Input image not found: {input_image_path}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ final_save_path, initial_db_location = prepare_output_path_with_upload(
+ task_id,
+ f"{task_id}_openpose.png",
+ main_output_dir_base,
+ dprint=dprint,
+ custom_output_dir=custom_output_dir
+ )
+
+ try:
+ pil_input_image = Image.open(input_image_path).convert("RGB")
+
+ pose_cfg_dict = {
+ "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
+ "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
+ "RESIZE_SIZE": 1024
+ }
+ if PoseBodyFaceVideoAnnotator is None:
+ raise ImportError("PoseBodyFaceVideoAnnotator could not be imported.")
+
+ pose_annotator = PoseBodyFaceVideoAnnotator(pose_cfg_dict)
+
+ openpose_np_frames_bgr = pose_annotator.forward([pil_input_image])
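+        # The annotator operates on a list of frames, so the single input image is wrapped in a
+        # one-element list and only the first result is used below.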
+
+ if not openpose_np_frames_bgr or openpose_np_frames_bgr[0] is None:
+ print(f"[ERROR Task ID: {task_id}] OpenPose generation failed or returned no frame.")
+ msg = "OpenPose generation returned no data."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ openpose_np_frame_bgr = openpose_np_frames_bgr[0]
+
+ openpose_pil_image = Image.fromarray(openpose_np_frame_bgr.astype(np.uint8))
+ openpose_pil_image.save(final_save_path)
+
+ # Upload to Supabase if configured
+ final_db_location = upload_and_get_final_output_location(
+ final_save_path, task_id, initial_db_location, dprint=dprint
+ )
+
+ print(f"[Task ID: {task_id}] Successfully generated OpenPose image to: {final_save_path.resolve()}")
+ return True, final_db_location
+
+ except ImportError as ie:
+ print(f"[ERROR Task ID: {task_id}] Import error during OpenPose generation: {ie}. Ensure 'preprocessing' module is in PYTHONPATH and dependencies are installed.")
+ traceback.print_exc()
+ msg = f"Import error: {ie}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ except FileNotFoundError as fnfe:
+ print(f"[ERROR Task ID: {task_id}] ONNX model file not found for OpenPose: {fnfe}. Ensure 'ckpts/pose/*' models are present.")
+ traceback.print_exc()
+ msg = f"ONNX model not found: {fnfe}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ except Exception as e:
+ print(f"[ERROR Task ID: {task_id}] Failed during OpenPose image generation: {e}")
+ traceback.print_exc()
+ msg = f"OpenPose generation exception: {e}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+def handle_extract_frame_task(task_params_dict: dict, main_output_dir_base: Path, task_id: str, dprint: callable):
+ """Handles the 'extract_frame' task."""
+ print(f"[Task ID: {task_id}] Handling 'extract_frame' task.")
+
+ input_video_task_id = task_params_dict.get("input_video_task_id")
+ frame_index = task_params_dict.get("frame_index", 0) # Default to first frame
+ custom_output_dir = task_params_dict.get("output_dir")
+
+ if not input_video_task_id:
+ msg = f"Task {task_id}: Missing 'input_video_task_id' in payload."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ try:
+ # Get the output path of the dependency task
+ # Note: This is looking up the direct task output, not a dependency relationship
+ video_path_from_db = db_ops.get_task_output_location_from_db(input_video_task_id)
+ if not video_path_from_db:
+ msg = f"Task {task_id}: Could not find output location for dependency task {input_video_task_id}."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ video_abs_path = db_ops.get_abs_path_from_db_path(video_path_from_db, dprint)
+ if not video_abs_path:
+ msg = f"Task {task_id}: Could not resolve or find video file from DB path '{video_path_from_db}'."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ # Use prepare_output_path_with_upload to determine the correct save location
+ output_filename = f"{task_id}_frame_{frame_index}.png"
+ final_save_path, initial_db_location = prepare_output_path_with_upload(
+ task_id,
+ output_filename,
+ main_output_dir_base,
+ dprint=dprint,
+ custom_output_dir=custom_output_dir
+ )
+
+ # The resolution for save_frame_from_video can be inferred from the video itself
+ # Or passed in the payload if a specific resize is needed. For now, we don't resize.
+ cap = cv2.VideoCapture(str(video_abs_path))
+ if not cap.isOpened():
+ msg = f"Task {task_id}: Could not open video file {video_abs_path}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ cap.release()
+
+ # Now use the save_frame_from_video utility
+ success = save_frame_from_video(
+ video_path=video_abs_path,
+ frame_index=frame_index,
+ output_image_path=final_save_path,
+ resolution=(width, height) # Use native resolution
+ )
+
+ if success:
+ # Upload to Supabase if configured
+ final_db_location = upload_and_get_final_output_location(
+ final_save_path, task_id, initial_db_location, dprint=dprint
+ )
+
+ print(f"[Task ID: {task_id}] Successfully extracted frame {frame_index} to: {final_save_path}")
+ return True, final_db_location
+ else:
+ msg = f"Task {task_id}: save_frame_from_video utility failed for video {video_abs_path}."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ except Exception as e:
+ error_msg = f"Task {task_id}: Failed during frame extraction: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ report_orchestrator_failure(task_params_dict, error_msg, dprint)
+ return False, str(e)
+
+def handle_rife_interpolate_task(wgp_mod, task_params_dict: dict, main_output_dir_base: Path, task_id: str, dprint: callable):
+ """Handles the 'rife_interpolate_images' task."""
+ print(f"[Task ID: {task_id}] Handling 'rife_interpolate_images' task.")
+
+ input_image_path1_str = task_params_dict.get("input_image_path1")
+ input_image_path2_str = task_params_dict.get("input_image_path2")
+ output_video_path_str = task_params_dict.get("output_path")
+ num_rife_frames = task_params_dict.get("frames")
+ resolution_str = task_params_dict.get("resolution")
+ custom_output_dir = task_params_dict.get("output_dir")
+
+ required_params = {
+ "input_image_path1": input_image_path1_str,
+ "input_image_path2": input_image_path2_str,
+ "output_path": output_video_path_str,
+ "frames": num_rife_frames,
+ "resolution": resolution_str
+ }
+ missing_params = [key for key, value in required_params.items() if value is None]
+ if missing_params:
+ error_msg = f"Missing required parameters for rife_interpolate_images: {', '.join(missing_params)}"
+ print(f"[ERROR Task ID: {task_id}] {error_msg}")
+ return False, error_msg
+
+ input_image1_path = Path(input_image_path1_str)
+ input_image2_path = Path(input_image_path2_str)
+ output_video_path = Path(output_video_path_str)
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ generation_success = False
+ output_location_to_db = None
+
+ final_save_path_for_video, initial_db_location_for_rife = prepare_output_path_with_upload(
+ task_id,
+ f"{task_id}_rife_interpolated.mp4",
+ main_output_dir_base,
+ dprint=dprint,
+ custom_output_dir=custom_output_dir
+ )
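+    # The payload's 'output_path' is required above, but the effective save location (and any
+    # upload handling) is determined by prepare_output_path_with_upload.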
+ output_video_path = final_save_path_for_video
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ dprint(f"[Task ID: {task_id}] Checking input image paths.")
+ if not input_image1_path.is_file():
+ print(f"[ERROR Task ID: {task_id}] Input image 1 not found: {input_image1_path}")
+ return False, f"Input image 1 not found: {input_image1_path}"
+ if not input_image2_path.is_file():
+ print(f"[ERROR Task ID: {task_id}] Input image 2 not found: {input_image2_path}")
+ return False, f"Input image 2 not found: {input_image2_path}"
+ dprint(f"[Task ID: {task_id}] Input images found.")
+
+ temp_output_dir = tempfile.mkdtemp(prefix=f"wgp_rife_{task_id}_")
+ original_wgp_save_path = wgp_mod.save_path
+ wgp_mod.save_path = str(temp_output_dir)
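+    # Redirect wgp's global save_path to a throwaway temp directory for the duration of this task
+    # (restored in the finally block); the RIFE helper writes directly to final_save_path_for_video,
+    # so nothing under the temp dir needs to be kept.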
+
+ try:
+ pil_image_start = Image.open(input_image1_path).convert("RGB")
+ pil_image_end = Image.open(input_image2_path).convert("RGB")
+
+ print(f"[Task ID: {task_id}] Starting RIFE interpolation via video_utils.")
+ dprint(f" Input 1: {input_image1_path}")
+ dprint(f" Input 2: {input_image2_path}")
+
+ rife_success = sm_rife_interpolate_images_to_video(
+ image1=pil_image_start,
+ image2=pil_image_end,
+ num_frames=int(num_rife_frames),
+ resolution_wh=sm_parse_resolution(resolution_str),
+ output_path=final_save_path_for_video,
+ fps=16,
+ dprint_func=lambda msg: dprint(f"[Task ID: {task_id}] (rife_util) {msg}")
+ )
+
+ if rife_success:
+ if final_save_path_for_video.exists() and final_save_path_for_video.stat().st_size > 0:
+ generation_success = True
+ # Upload to Supabase if configured
+ output_location_to_db = upload_and_get_final_output_location(
+ final_save_path_for_video, task_id, initial_db_location_for_rife, dprint=dprint
+ )
+ print(f"[Task ID: {task_id}] RIFE video saved to: {final_save_path_for_video.resolve()} (DB: {output_location_to_db})")
+ else:
+ print(f"[ERROR Task ID: {task_id}] RIFE utility reported success, but output file is missing or empty: {final_save_path_for_video}")
+ generation_success = False
+ else:
+ print(f"[ERROR Task ID: {task_id}] RIFE interpolation using video_utils failed.")
+ generation_success = False
+
+ except Exception as e:
+ print(f"[ERROR Task ID: {task_id}] Overall _handle_rife_interpolate_task failed: {e}")
+ traceback.print_exc()
+ generation_success = False
+ finally:
+ wgp_mod.save_path = original_wgp_save_path
+
+ try:
+ shutil.rmtree(temp_output_dir)
+ dprint(f"[Task ID: {task_id}] Cleaned up temporary directory: {temp_output_dir}")
+ except Exception as e_clean:
+ print(f"[WARNING Task ID: {task_id}] Failed to clean up temporary directory {temp_output_dir}: {e_clean}")
+
+
+ return generation_success, output_location_to_db
+
+def handle_generate_depth_task(task_params_dict: dict, main_output_dir_base: Path, task_id: str, dprint: callable):
+ """Handles the 'generate_depth' task – produces a MiDaS depth‐map PNG from an image."""
+ print(f"[Task ID: {task_id}] Handling 'generate_depth' task.")
+
+ input_image_path_str = task_params_dict.get("input_image_path")
+ input_image_task_id = task_params_dict.get("input_image_task_id")
+ custom_output_dir = task_params_dict.get("output_dir")
+
+ if DepthAnnotator is None:
+ msg = "DepthAnnotator not available. Ensure MiDaS preprocessing module is installed."
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ # Resolve image path from dependency task if direct path is missing
+ if not input_image_path_str:
+ dprint(f"Task {task_id}: 'input_image_path' missing, trying 'input_image_task_id': {input_image_task_id}")
+ if not input_image_task_id:
+ msg = "Task requires either 'input_image_path' or 'input_image_task_id'."
+ print(f"[ERROR Task ID: {task_id}] {msg}")
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ try:
+ # Get the output path of the dependency task
+ # Note: This is looking up the direct task output, not a dependency relationship
+ path_from_db = db_ops.get_task_output_location_from_db(input_image_task_id)
+ if not path_from_db:
+ msg = f"Task {task_id}: Could not find output location for dependency task {input_image_task_id}."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ abs_path = db_ops.get_abs_path_from_db_path(path_from_db, dprint)
+ if not abs_path:
+ msg = f"Task {task_id}: Could not resolve or find image file from DB path '{path_from_db}'."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ input_image_path_str = str(abs_path)
+ dprint(f"Task {task_id}: Resolved input image path from task ID to '{input_image_path_str}'")
+ except Exception as e:
+ error_msg = f"Task {task_id}: Failed during input image path resolution: {e}"
+ print(f"[ERROR] {error_msg}")
+ traceback.print_exc()
+ report_orchestrator_failure(task_params_dict, error_msg, dprint)
+ return False, str(e)
+
+ input_image_path = Path(input_image_path_str)
+ if not input_image_path.is_file():
+ print(f"[ERROR Task ID: {task_id}] Input image file not found: {input_image_path}")
+ msg = f"Input image not found: {input_image_path}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ # Prepare save path for depth PNG
+ final_save_path, initial_db_location = prepare_output_path_with_upload(
+ task_id,
+ f"{task_id}_depth.png",
+ main_output_dir_base,
+ dprint=dprint,
+ custom_output_dir=custom_output_dir,
+ )
+
+ try:
+ pil_input_image = Image.open(input_image_path).convert("RGB")
+ depth_cfg = {"PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"}
+ depth_annotator = DepthAnnotator(depth_cfg)
+
+ depth_np = depth_annotator.forward(pil_input_image) # Returns RGB depth map as numpy array
+ if depth_np is None:
+ msg = "Depth generation returned no data."
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+
+ depth_pil = Image.fromarray(depth_np.astype(np.uint8))
+ depth_pil.save(final_save_path)
+
+ # Upload to Supabase if configured
+ final_db_location = upload_and_get_final_output_location(
+ final_save_path, task_id, initial_db_location, dprint=dprint
+ )
+
+ print(f"[Task ID: {task_id}] Successfully generated depth map: {final_save_path.resolve()}")
+ return True, final_db_location
+
+ except ImportError as ie:
+ print(f"[ERROR Task ID: {task_id}] Import error during depth generation: {ie}")
+ traceback.print_exc()
+ msg = f"Import error: {ie}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ except FileNotFoundError as fnfe:
+ print(f"[ERROR Task ID: {task_id}] Model file not found for MiDaS depth: {fnfe}")
+ traceback.print_exc()
+ msg = f"Model not found: {fnfe}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
+ except Exception as e:
+ print(f"[ERROR Task ID: {task_id}] Failed during depth map generation: {e}")
+ traceback.print_exc()
+ msg = f"Depth generation exception: {e}"
+ report_orchestrator_failure(task_params_dict, msg, dprint)
+ return False, msg
\ No newline at end of file
diff --git a/source/video_utils.py b/source/video_utils.py
new file mode 100644
index 000000000..44c04ed12
--- /dev/null
+++ b/source/video_utils.py
@@ -0,0 +1,1196 @@
+import math
+import subprocess
+from pathlib import Path
+import traceback
+import os
+import sys
+import json
+import time
+
+_MOVIEPY_AVAILABLE = False  # Overwritten below if the optional moviepy import succeeds
+_COLOR_MATCH_DEPS_AVAILABLE = False  # Flipped to True once all required imports load
+
+try:
+ import cv2 # pip install opencv-python
+ import numpy as np
+ from PIL import Image
+ import torch
+ try:
+ import moviepy.editor as mpe # pip install moviepy
+ _MOVIEPY_AVAILABLE = True
+ except ImportError:
+ _MOVIEPY_AVAILABLE = False
+ _COLOR_MATCH_DEPS_AVAILABLE = True
+
+ # Add project root to path to allow absolute imports from source
+ project_root = Path(__file__).resolve().parent.parent
+ if str(project_root) not in sys.path:
+ sys.path.insert(0, str(project_root))
+
+ # Now that root is in path, we can import from Wan2GP and source
+ from Wan2GP.rife.inference import temporal_interpolation
+ from source.common_utils import (
+ dprint, get_video_frame_count_and_fps,
+ download_image_if_url, sm_get_unique_target_path,
+ _apply_strength_to_image as sm_apply_strength_to_image,
+ create_color_frame as sm_create_color_frame,
+ image_to_frame as sm_image_to_frame,
+ _adjust_frame_brightness as sm_adjust_frame_brightness,
+ get_easing_function as sm_get_easing_function,
+ wait_for_file_stable as sm_wait_for_file_stable
+ )
+except ImportError as e_import:
+ print(f"Critical import error in video_utils.py: {e_import}")
+ traceback.print_exc()
+ _COLOR_MATCH_DEPS_AVAILABLE = False
+
+def crossfade_ease(alpha_lin: float) -> float:
+ """Cosine ease-in-out function (maps 0..1 to 0..1).
+ Used to determine the blending alpha for crossfades.
+ """
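+    # e.g. crossfade_ease(0.0) == 0.0, crossfade_ease(0.5) == 0.5, crossfade_ease(1.0) == 1.0,
+    # and crossfade_ease(0.25) is about 0.146, so blends ramp in gently at both ends.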
+ return (1 - math.cos(alpha_lin * math.pi)) / 2.0
+
+def _blend_linear(a: np.ndarray, b: np.ndarray, t: float) -> np.ndarray:
+    """Linear blend of two frames: (1 - t) * a + t * b."""
+    return cv2.addWeighted(a, 1.0 - t, b, t, 0)
+
+def _blend_linear_sharp(a: np.ndarray, b: np.ndarray, t: float, amt: float) -> np.ndarray:
+    """Linear blend followed by an unsharp-mask style sharpen whose strength scales with amt * t."""
+    base = _blend_linear(a, b, t)
+    if amt <= 0:
+        return base
+    blur = cv2.GaussianBlur(base, (0, 0), 3)
+    return cv2.addWeighted(base, 1.0 + amt * t, blur, -amt * t, 0)
+
+def cross_fade_overlap_frames(
+ segment1_frames: list[np.ndarray],
+ segment2_frames: list[np.ndarray],
+ overlap_count: int,
+ mode: str = "linear_sharp",
+ sharp_amt: float = 0.3
+) -> list[np.ndarray]:
+ """
+ Cross-fades the overlapping frames between two segments using various modes.
+
+ Args:
+ segment1_frames: Frames from the first segment (video ending)
+ segment2_frames: Frames from the second segment (video starting)
+ overlap_count: Number of frames to cross-fade
+ mode: Blending mode ("linear", "linear_sharp")
+ sharp_amt: Sharpening amount for "linear_sharp" mode (0-1)
+
+ Returns:
+ List of cross-faded frames for the overlap region
+ """
+ if overlap_count <= 0:
+ return []
+
+ n = min(overlap_count, len(segment1_frames), len(segment2_frames))
+ if n <= 0:
+ return []
+
+ # Determine target resolution from segment2 (the newer generated video)
+ if not segment2_frames:
+ return []
+
+ target_height, target_width = segment2_frames[0].shape[:2]
+ target_resolution = (target_width, target_height)
+
+ # Log dimension information for debugging
+ seg1_height, seg1_width = segment1_frames[0].shape[:2] if segment1_frames else (0, 0)
+ dprint(f"CrossFade: Segment1 resolution: {seg1_width}x{seg1_height}, Segment2 resolution: {target_width}x{target_height}")
+ dprint(f"CrossFade: Target resolution set to: {target_width}x{target_height} (from segment2)")
+ dprint(f"CrossFade: Processing {n} overlap frames")
+
+ out_frames = []
+ for i in range(n):
+ t_linear = (i + 1) / float(n)
+ alpha = crossfade_ease(t_linear)
+
+ frame_a_np = segment1_frames[-n+i].astype(np.float32)
+ frame_b_np = segment2_frames[i].astype(np.float32)
+
+ # Log original shapes before any resizing
+ original_a_shape = frame_a_np.shape[:2]
+ original_b_shape = frame_b_np.shape[:2]
+
+ # Ensure both frames have the same dimensions
+ if frame_a_np.shape[:2] != (target_height, target_width):
+ frame_a_np = cv2.resize(frame_a_np, target_resolution, interpolation=cv2.INTER_AREA).astype(np.float32)
+ if i == 0: # Log only for first frame to avoid spam
+ dprint(f"CrossFade: Resized segment1 frame from {original_a_shape} to {frame_a_np.shape[:2]}")
+ if frame_b_np.shape[:2] != (target_height, target_width):
+ frame_b_np = cv2.resize(frame_b_np, target_resolution, interpolation=cv2.INTER_AREA).astype(np.float32)
+ if i == 0: # Log only for first frame to avoid spam
+ dprint(f"CrossFade: Resized segment2 frame from {original_b_shape} to {frame_b_np.shape[:2]}")
+
+ blended_float: np.ndarray
+ if mode == "linear_sharp":
+ blended_float = _blend_linear_sharp(frame_a_np, frame_b_np, alpha, sharp_amt)
+ elif mode == "linear":
+ blended_float = _blend_linear(frame_a_np, frame_b_np, alpha)
+ else:
+ dprint(f"Warning: Unknown crossfade mode '{mode}'. Defaulting to linear.")
+ blended_float = _blend_linear(frame_a_np, frame_b_np, alpha)
+
+ blended_uint8 = np.clip(blended_float, 0, 255).astype(np.uint8)
+ out_frames.append(blended_uint8)
+
+ return out_frames
+
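+# Illustrative usage sketch (not called anywhere in this module): stitching two
+# already-rendered segments with an 8-frame crossfade. The file paths and the
+# overlap length are hypothetical; real callers derive them from the task payload.
+#
+#   seg1 = extract_frames_from_video("outputs/segment_00.mp4")
+#   seg2 = extract_frames_from_video("outputs/segment_01.mp4")
+#   blend = cross_fade_overlap_frames(seg1, seg2, overlap_count=8, mode="linear_sharp")
+#   stitched = seg1[:-8] + blend + seg2[8:]
+#   h, w = stitched[0].shape[:2]
+#   create_video_from_frames_list(stitched, "outputs/stitched.mp4", fps=16, resolution=(w, h))
+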
+def extract_frames_from_video(video_path: str | Path, start_frame: int = 0, num_frames: int = None, *, dprint_func=print) -> list[np.ndarray]:
+ """
+ Extracts frames from a video file as numpy arrays.
+
+ Args:
+ video_path: Path to the video file
+ start_frame: Starting frame index (0-based)
+ num_frames: Number of frames to extract (None = all remaining frames)
+ dprint_func: The function to use for printing debug messages
+
+ Returns:
+ List of frames as BGR numpy arrays
+ """
+ frames = []
+ cap = cv2.VideoCapture(str(video_path))
+
+ if not cap.isOpened():
+ dprint_func(f"Error: Could not open video {video_path}")
+ return frames
+
+ total_frames_video = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, float(start_frame))
+
+ frames_to_read = num_frames if num_frames is not None else (total_frames_video - start_frame)
+ frames_to_read = min(frames_to_read, total_frames_video - start_frame)
+
+ for i in range(frames_to_read):
+ ret, frame = cap.read()
+ if not ret:
+ dprint_func(f"Warning: Could not read frame {start_frame + i} from {video_path}")
+ break
+ frames.append(frame)
+
+ cap.release()
+ return frames
+
+def create_video_from_frames_list(
+ frames_list: list[np.ndarray],
+ output_path: str | Path,
+ fps: int,
+ resolution: tuple[int, int]
+) -> Path | None:
+ """Creates a video from a list of NumPy BGR frames using FFmpeg subprocess.
+ Returns the Path object of the successfully written file, or None if failed.
+ """
+ output_path_obj = Path(output_path)
+ output_path_mp4 = output_path_obj.with_suffix('.mp4')
+ output_path_mp4.parent.mkdir(parents=True, exist_ok=True)
+
+ ffmpeg_cmd = [
+ "ffmpeg", "-y",
+ "-loglevel", "error",
+ "-f", "rawvideo",
+ "-vcodec", "rawvideo",
+ "-pix_fmt", "bgr24",
+ "-s", f"{resolution[0]}x{resolution[1]}",
+ "-r", str(fps),
+ "-i", "-",
+ "-c:v", "libx264",
+ "-pix_fmt", "yuv420p",
+ "-preset", "veryfast",
+ "-crf", "23",
+ str(output_path_mp4.resolve())
+ ]
+
+ processed_frames = []
+ for frame_idx, frame_np in enumerate(frames_list):
+ if frame_np is None or not isinstance(frame_np, np.ndarray):
+ continue
+ if frame_np.dtype != np.uint8:
+ frame_np = frame_np.astype(np.uint8)
+ if frame_np.shape[0] != resolution[1] or frame_np.shape[1] != resolution[0] or frame_np.shape[2] != 3:
+ try:
+ frame_np = cv2.resize(frame_np, resolution, interpolation=cv2.INTER_AREA)
+ except Exception:
+ continue
+ processed_frames.append(frame_np)
+
+ if not processed_frames:
+ return None
+
+ try:
+ raw_video_data = b''.join(frame.tobytes() for frame in processed_frames)
+ except Exception:
+ return None
+
+ if not raw_video_data:
+ return None
+
+ try:
+ proc = subprocess.run(
+ ffmpeg_cmd,
+ input=raw_video_data,
+ capture_output=True,
+ timeout=60
+ )
+
+ if proc.returncode == 0:
+ if output_path_mp4.exists() and output_path_mp4.stat().st_size > 0:
+ return output_path_mp4
+ return None
+ else:
+ if output_path_mp4.exists():
+ try:
+ output_path_mp4.unlink()
+ except Exception:
+ pass
+ return None
+
+ except subprocess.TimeoutExpired:
+ return None
+ except FileNotFoundError:
+ return None
+ except Exception:
+ return None
+
+def _apply_saturation_to_video_ffmpeg(
+ input_video_path: str | Path,
+ output_video_path: str | Path,
+ saturation_level: float,
+ preset: str = "veryfast"
+) -> bool:
+ """Applies a saturation adjustment to the full video using FFmpeg's eq filter.
+ Returns: True if FFmpeg succeeds and the output file exists & is non-empty, else False.
+ """
+ inp = Path(input_video_path)
+ outp = Path(output_video_path)
+ outp.parent.mkdir(parents=True, exist_ok=True)
+
+ cmd = [
+ "ffmpeg", "-y",
+ "-i", str(inp.resolve()),
+ "-vf", f"eq=saturation={saturation_level}",
+ "-c:v", "libx264",
+ "-preset", preset,
+ "-pix_fmt", "yuv420p",
+ "-an",
+ str(outp.resolve())
+ ]
+
+ try:
+ subprocess.run(cmd, check=True, capture_output=True, text=True, encoding="utf-8")
+ if outp.exists() and outp.stat().st_size > 0:
+ return True
+ return False
+ except subprocess.CalledProcessError:
+ return False
+
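+# The helper above is roughly equivalent to invoking ffmpeg directly, e.g. with
+# illustrative paths and saturation_level=0.85:
+#
+#   ffmpeg -y -i input.mp4 -vf "eq=saturation=0.85" -c:v libx264 \
+#          -preset veryfast -pix_fmt yuv420p -an output.mp4
+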
+def color_match_video_to_reference(source_video_path: str | Path, reference_video_path: str | Path,
+ output_video_path: str | Path, parsed_resolution: tuple[int, int]) -> bool:
+ """
+ Color matches source_video to reference_video using histogram matching on the last frame of reference
+ and first frame of source. Applies the transformation to all frames of source_video.
+ Returns True if successful, False otherwise.
+ """
+ try:
+ ref_cap = cv2.VideoCapture(str(reference_video_path))
+ if not ref_cap.isOpened():
+ return False
+
+ ref_frame_count_from_cap, _ = get_video_frame_count_and_fps(str(reference_video_path))
+ if ref_frame_count_from_cap is None: ref_frame_count_from_cap = int(ref_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ ref_cap.set(cv2.CAP_PROP_POS_FRAMES, float(max(0, ref_frame_count_from_cap - 1)))
+ ret, ref_frame = ref_cap.read()
+ ref_cap.release()
+
+ if not ret or ref_frame is None:
+ return False
+
+ if ref_frame.shape[1] != parsed_resolution[0] or ref_frame.shape[0] != parsed_resolution[1]:
+ ref_frame = cv2.resize(ref_frame, parsed_resolution, interpolation=cv2.INTER_AREA)
+
+ src_cap = cv2.VideoCapture(str(source_video_path))
+ if not src_cap.isOpened():
+ return False
+
+ src_fps_from_cap = src_cap.get(cv2.CAP_PROP_FPS)
+
+ ret, src_first_frame = src_cap.read()
+ if not ret or src_first_frame is None:
+ src_cap.release()
+ return False
+
+ if src_first_frame.shape[1] != parsed_resolution[0] or src_first_frame.shape[0] != parsed_resolution[1]:
+ src_first_frame = cv2.resize(src_first_frame, parsed_resolution, interpolation=cv2.INTER_AREA)
+
+ def match_histogram_channel(source_channel, reference_channel):
+ src_hist, _ = np.histogram(source_channel.flatten(), 256, [0, 256])
+ ref_hist, _ = np.histogram(reference_channel.flatten(), 256, [0, 256])
+ src_cdf = src_hist.cumsum()
+ ref_cdf = ref_hist.cumsum()
+ src_cdf = src_cdf / src_cdf[-1]
+ ref_cdf = ref_cdf / ref_cdf[-1]
+ lookup_table = np.zeros(256, dtype=np.uint8)
+ for i in range(256):
+ closest_idx = np.argmin(np.abs(ref_cdf - src_cdf[i]))
+ lookup_table[i] = closest_idx
+ return lookup_table
+
+ lookup_tables = []
+ for channel in range(3):
+ lut = match_histogram_channel(src_first_frame[:, :, channel], ref_frame[:, :, channel])
+ lookup_tables.append(lut)
+
+ output_path_obj = Path(output_video_path)
+ output_path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ out = cv2.VideoWriter(str(output_path_obj), fourcc, float(src_fps_from_cap), parsed_resolution)
+ if not out.isOpened():
+ src_cap.release()
+ return False
+
+ src_cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+ while True:
+ ret, frame = src_cap.read()
+ if not ret: break
+ if frame.shape[1] != parsed_resolution[0] or frame.shape[0] != parsed_resolution[1]:
+ frame = cv2.resize(frame, parsed_resolution, interpolation=cv2.INTER_AREA)
+ matched_frame = frame.copy()
+ for channel in range(3):
+ matched_frame[:, :, channel] = cv2.LUT(frame[:, :, channel], lookup_tables[channel])
+ out.write(matched_frame)
+
+ src_cap.release()
+ out.release()
+
+ return True
+
+ except Exception:
+ return False
+
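+# Note: the per-channel lookup tables above implement classic histogram matching;
+# each source intensity is mapped to the reference intensity whose cumulative
+# distribution value is closest, and the LUTs are then applied to every frame.
+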
+# --- Video Brightness Adjustment Functions ---
+
+def get_video_frame_count_and_fps(video_path: str) -> tuple[int, float] | tuple[None, None]:
+ """
+ Get frame count and FPS from a video file using OpenCV.
+
+ Note: OpenCV can sometimes return incorrect frame counts for recently encoded videos.
+ Consider using manual frame extraction for critical operations where accuracy is essential.
+ """
+
+ # Try multiple times in case the video metadata is still being written
+ max_attempts = 3
+ for attempt in range(max_attempts):
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ if attempt < max_attempts - 1:
+ time.sleep(0.5) # Wait a bit before retrying
+ continue
+ return None, None
+
+ frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ cap.release()
+
+ # If we got valid values, return them
+ if frames > 0 and fps > 0:
+ return frames, fps
+
+ # Otherwise, wait and retry
+ if attempt < max_attempts - 1:
+ time.sleep(0.5)
+
+ # If all attempts failed, return what we got (might be 0 or invalid)
+ return frames, fps
+
+def adjust_frame_brightness(frame: np.ndarray, brightness_adjust: float) -> np.ndarray:
+ if brightness_adjust == 0:
+ return frame
+ factor = 1 + brightness_adjust
+ adjusted = np.clip(frame.astype(np.float32) * factor, 0, 255).astype(np.uint8)
+ return adjusted
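+# e.g. brightness_adjust=-0.2 scales every pixel by 0.8 (darker) and +0.2 scales
+# them by 1.2 (brighter); results are clipped to the valid [0, 255] range.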
+
+def apply_brightness_to_video_frames(input_video_path: str, output_video_path: Path, brightness_adjust: float, task_id_for_logging: str) -> Path | None:
+ """
+ Applies brightness adjustment to a video by processing its frames.
+ A brightness_adjust of 0 means no change. Negative values darken, positive values brighten.
+ """
+ try:
+ print(f"Task {task_id_for_logging}: Applying brightness adjustment {brightness_adjust} to {input_video_path}")
+
+ total_frames, fps = get_video_frame_count_and_fps(input_video_path)
+ if total_frames is None or fps is None or total_frames == 0:
+ print(f"[ERROR] Task {task_id_for_logging}: Could not get frame count or fps for {input_video_path}, or video has 0 frames.")
+ return None
+
+        frames = extract_frames_from_video(input_video_path)
+        if not frames:
+            print(f"[ERROR] Task {task_id_for_logging}: Could not extract frames from {input_video_path}")
+            return None
+
+ adjusted_frames = []
+ first_frame = None
+ for frame in frames:
+ if first_frame is None:
+ first_frame = frame
+ adjusted_frame = adjust_frame_brightness(frame, brightness_adjust)
+ adjusted_frames.append(adjusted_frame)
+
+ if not adjusted_frames or first_frame is None:
+ print(f"[ERROR] Task {task_id_for_logging}: No frames to write for brightness-adjusted video.")
+ return None
+
+ h, w, _ = first_frame.shape
+ resolution = (w, h)
+
+ created_video_path = create_video_from_frames_list(adjusted_frames, output_video_path, fps, resolution)
+ if created_video_path and created_video_path.exists():
+ print(f"Task {task_id_for_logging}: Successfully created brightness-adjusted video at {created_video_path}")
+ return created_video_path
+ else:
+ print(f"[ERROR] Task {task_id_for_logging}: Failed to create brightness-adjusted video.")
+ return None
+ except Exception as e:
+ print(f"[ERROR] Task {task_id_for_logging}: Exception in apply_brightness_to_video_frames: {e}")
+ traceback.print_exc()
+ return None
+
+def rife_interpolate_images_to_video(
+ image1: Image.Image,
+ image2: Image.Image,
+ num_frames: int,
+ resolution_wh: tuple[int, int],
+ output_path: str | Path,
+ fps: int = 16,
+ dprint_func=print
+) -> bool:
+ """
+ Interpolates between two PIL images using RIFE to generate a video.
+ """
+ try:
+ dprint_func("Imported RIFE modules for interpolation.")
+
+ width_out, height_out = resolution_wh
+ dprint_func(f"Parsed resolution: {width_out}x{height_out}")
+
+ def pil_to_tensor_rgb_norm(pil_im: Image.Image):
+ pil_resized = pil_im.resize((width_out, height_out), Image.Resampling.LANCZOS)
+ np_rgb = np.asarray(pil_resized).astype(np.float32) / 127.5 - 1.0 # [0,255]->[-1,1]
+ tensor = torch.from_numpy(np_rgb).permute(2, 0, 1) # C H W
+ return tensor
+
+ t_start = pil_to_tensor_rgb_norm(image1)
+ t_end = pil_to_tensor_rgb_norm(image2)
+
+ sample_in = torch.stack([t_start, t_end], dim=1).unsqueeze(0) # 1 x 3 x 2 x H x W
+
+ device_for_rife = "cuda" if torch.cuda.is_available() else "cpu"
+ sample_in = sample_in.to(device_for_rife)
+ dprint_func(f"Input tensor for RIFE prepared on device: {device_for_rife}, shape: {sample_in.shape}")
+
+ exp_val = 3 # x8 (2^3 + 1 = 9 frames output by this RIFE implementation for 2 inputs)
+ flownet_ckpt = os.path.join("ckpts", "flownet.pkl")
+ dprint_func(f"Checking for RIFE model: {flownet_ckpt}")
+ if not os.path.exists(flownet_ckpt):
+ dprint_func(f"RIFE Error: flownet.pkl not found at {flownet_ckpt}")
+ return False
+ dprint_func(f"RIFE model found: {flownet_ckpt}. Exp_val: {exp_val}")
+
+ sample_in_for_rife = sample_in[0]
+
+ sample_out_from_rife = temporal_interpolation(flownet_ckpt, sample_in_for_rife, exp_val, device=device_for_rife)
+
+ if sample_out_from_rife is None:
+ dprint_func("RIFE process returned None.")
+ return False
+
+ dprint_func(f"RIFE output tensor shape: {sample_out_from_rife.shape}")
+
+ sample_out_no_batch = sample_out_from_rife.to("cpu")
+ total_frames_generated = sample_out_no_batch.shape[1]
+ dprint_func(f"RIFE produced {total_frames_generated} frames.")
+
+ if total_frames_generated < num_frames:
+ dprint_func(f"Warning: RIFE produced {total_frames_generated} frames, expected {num_frames}. Padding last frame.")
+ pad_frames = num_frames - total_frames_generated
+ else:
+ pad_frames = 0
+
+ frames_list_np = []
+ for idx in range(min(num_frames, total_frames_generated)):
+ frame_tensor = sample_out_no_batch[:, idx]
+ frame_np = ((frame_tensor.permute(1, 2, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
+ frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
+ frames_list_np.append(frame_bgr)
+
+ if pad_frames > 0 and frames_list_np:
+ last_frame_to_pad = frames_list_np[-1].copy()
+ frames_list_np.extend([last_frame_to_pad for _ in range(pad_frames)])
+
+ if not frames_list_np:
+ dprint_func(f"Error: No frames available to write for RIFE video (num_rife_frames: {num_frames}).")
+ return False
+
+ output_path_obj = Path(output_path)
+ video_written = create_video_from_frames_list(frames_list_np, output_path_obj, fps, resolution_wh)
+
+ if video_written:
+ dprint_func(f"RIFE video saved to: {video_written.resolve()}")
+ return True
+ else:
+ dprint_func(f"RIFE output file missing or empty after writing attempt: {output_path_obj}")
+ return False
+
+ except Exception as e:
+ dprint_func(f"RIFE interpolation failed with exception: {e}")
+ traceback.print_exc()
+ return False
+
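+# Illustrative sketch (hypothetical image paths and resolution): building a short
+# RIFE bridge clip between two stills. exp_val=3 inside the helper yields 9 frames,
+# so any num_frames beyond that are padded with the last frame.
+#
+#   from PIL import Image
+#   ok = rife_interpolate_images_to_video(
+#       Image.open("anchor_a.png"), Image.open("anchor_b.png"),
+#       num_frames=9, resolution_wh=(832, 480),
+#       output_path="outputs/rife_bridge.mp4", fps=16,
+#   )
+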
+def prepare_vace_ref_for_segment(
+ ref_instruction: dict,
+ segment_processing_dir: Path,
+ target_resolution_wh: tuple[int, int] | None,
+ image_download_dir: Path | str | None = None,
+ task_id_for_logging: str | None = "generic_headless_task"
+) -> Path | None:
+ '''
+ Prepares a VACE reference image for a segment based on the given instruction.
+ Downloads the image if 'original_path' is a URL and image_download_dir is provided.
+ Applies strength adjustment and resizes, saving the result to segment_processing_dir.
+ Returns the path to the processed image if successful, or None otherwise.
+ '''
+ dprint(f"Task {task_id_for_logging} (prepare_vace_ref): VACE Ref instruction: {ref_instruction}, download_dir: {image_download_dir}")
+
+ original_image_path_str = ref_instruction.get("original_path")
+ strength_to_apply = ref_instruction.get("strength_to_apply")
+
+ if not original_image_path_str:
+ dprint(f"Task {task_id_for_logging}, Segment {segment_processing_dir.name}: No original_path in VACE ref instruction. Skipping.")
+ return None
+
+ local_original_image_path_str = download_image_if_url(original_image_path_str, image_download_dir, task_id_for_logging)
+ local_original_image_path = Path(local_original_image_path_str)
+
+ if not local_original_image_path.exists():
+ dprint(f"Task {task_id_for_logging}, Segment {segment_processing_dir.name}: VACE ref original image not found (after potential download): {local_original_image_path} (original input: {original_image_path_str})")
+ return None
+
+ vace_ref_type = ref_instruction.get("type", "generic")
+ segment_idx_for_naming = ref_instruction.get("segment_idx_for_naming", "unknown_idx")
+ processed_vace_base_name = f"vace_ref_s{segment_idx_for_naming}_{vace_ref_type}_str{strength_to_apply:.2f}"
+ original_suffix = local_original_image_path.suffix if local_original_image_path.suffix else ".png"
+
+ output_path_for_processed_vace = sm_get_unique_target_path(segment_processing_dir, processed_vace_base_name, original_suffix)
+
+ effective_target_resolution_wh = None
+ if target_resolution_wh:
+ effective_target_resolution_wh = ((target_resolution_wh[0] // 16) * 16, (target_resolution_wh[1] // 16) * 16)
+ if effective_target_resolution_wh != target_resolution_wh:
+ dprint(f"Task {task_id_for_logging}, Segment {segment_processing_dir.name}: Adjusted VACE ref target resolution from {target_resolution_wh} to {effective_target_resolution_wh}")
+
+ final_processed_path = sm_apply_strength_to_image(
+ image_path_input=local_original_image_path,
+ strength=strength_to_apply,
+ output_path=output_path_for_processed_vace,
+ target_resolution_wh=effective_target_resolution_wh,
+ task_id_for_logging=task_id_for_logging,
+ image_download_dir=None
+ )
+
+ if final_processed_path and final_processed_path.exists():
+ dprint(f"Task {task_id_for_logging}, Segment {segment_processing_dir.name}: Prepared VACE ref: {final_processed_path}")
+ return final_processed_path
+ else:
+ dprint(f"Task {task_id_for_logging}, Segment {segment_processing_dir.name}: Failed to apply strength/save VACE ref from {local_original_image_path}. Skipping.")
+ traceback.print_exc()
+ return None
+
+def create_guide_video_for_travel_segment(
+ segment_idx_for_logging: int,
+ end_anchor_image_index: int,
+ is_first_segment_from_scratch: bool,
+ total_frames_for_segment: int,
+ parsed_res_wh: tuple[int, int],
+ fps_helpers: int,
+ input_images_resolved_for_guide: list[str],
+ path_to_previous_segment_video_output_for_guide: str | None,
+ output_target_dir: Path,
+ guide_video_base_name: str,
+ segment_image_download_dir: Path | None,
+ task_id_for_logging: str,
+ full_orchestrator_payload: dict,
+ segment_params: dict,
+ single_image_journey: bool = False,
+ *,
+ dprint=print
+) -> Path | None:
+ """Creates the guide video for a travel segment with all fading and adjustments."""
+ try:
+ actual_guide_video_path = sm_get_unique_target_path(output_target_dir, guide_video_base_name, ".mp4")
+ gray_frame_bgr = sm_create_color_frame(parsed_res_wh, (128, 128, 128))
+
+ fade_in_p = json.loads(full_orchestrator_payload["fade_in_params_json_str"])
+ fade_out_p = json.loads(full_orchestrator_payload["fade_out_params_json_str"])
+ strength_adj = segment_params.get("subsequent_starting_strength_adjustment", 0.0)
+ desat_factor = segment_params.get("desaturate_subsequent_starting_frames", 0.0)
+ bright_adj = segment_params.get("adjust_brightness_subsequent_starting_frames", 0.0)
+ frame_overlap_from_previous = segment_params.get("frame_overlap_from_previous", 0)
+
+ fi_low, fi_high, fi_curve, fi_factor = float(fade_in_p.get("low_point",0)), float(fade_in_p.get("high_point",1)), str(fade_in_p.get("curve_type","ease_in_out")), float(fade_in_p.get("duration_factor",0))
+ fo_low, fo_high, fo_curve, fo_factor = float(fade_out_p.get("low_point",0)), float(fade_out_p.get("high_point",1)), str(fade_out_p.get("curve_type","ease_in_out")), float(fade_out_p.get("duration_factor",0))
+
+ if total_frames_for_segment <= 0:
+ dprint(f"Task {task_id_for_logging}: Guide video has 0 frames. Skipping creation.")
+ return None
+
+ dprint(f"Task {task_id_for_logging}: Interpolating guide video with {total_frames_for_segment} frames...")
+ frames_for_guide_list = [sm_create_color_frame(parsed_res_wh, (128,128,128)).copy() for _ in range(total_frames_for_segment)]
+
+ end_anchor_frame_np = None
+
+ if not single_image_journey:
+ end_anchor_img_path_str: str
+ if end_anchor_image_index < len(input_images_resolved_for_guide):
+ end_anchor_img_path_str = input_images_resolved_for_guide[end_anchor_image_index]
+ else:
+ raise ValueError(f"Seg {segment_idx_for_logging}: End anchor index {end_anchor_image_index} out of bounds for input images list ({len(input_images_resolved_for_guide)} images available).")
+
+ end_anchor_frame_np = sm_image_to_frame(end_anchor_img_path_str, parsed_res_wh, task_id_for_logging=task_id_for_logging, image_download_dir=segment_image_download_dir)
+ if end_anchor_frame_np is None: raise ValueError(f"Failed to load end anchor image: {end_anchor_img_path_str}")
+
+ num_end_anchor_duplicates = 1
+ start_anchor_frame_np = None
+
+ if is_first_segment_from_scratch:
+ start_anchor_img_path_str = input_images_resolved_for_guide[0]
+ start_anchor_frame_np = sm_image_to_frame(start_anchor_img_path_str, parsed_res_wh, task_id_for_logging=task_id_for_logging, image_download_dir=segment_image_download_dir)
+ if start_anchor_frame_np is None: raise ValueError(f"Failed to load start anchor: {start_anchor_img_path_str}")
+ if frames_for_guide_list: frames_for_guide_list[0] = start_anchor_frame_np.copy()
+
+ if single_image_journey:
+ dprint(f"Task {task_id_for_logging}: Guide video for single image journey. Only first frame is set.")
+ else:
+ # This is the original logic for fading between start and end.
+ pot_max_idx_start_fade = total_frames_for_segment - num_end_anchor_duplicates - 1
+ avail_frames_start_fade = max(0, pot_max_idx_start_fade)
+ num_start_fade_steps = int(avail_frames_start_fade * fo_factor)
+ if num_start_fade_steps > 0:
+ easing_fn_out = sm_get_easing_function(fo_curve)
+ for k_fo in range(num_start_fade_steps):
+ idx_in_guide = 1 + k_fo
+ if idx_in_guide >= total_frames_for_segment: break
+ alpha_lin = 1.0 - ((k_fo + 1) / float(num_start_fade_steps))
+ e_alpha = fo_low + (fo_high - fo_low) * easing_fn_out(alpha_lin)
+ frames_for_guide_list[idx_in_guide] = cv2.addWeighted(frames_for_guide_list[idx_in_guide].astype(np.float32), 1.0 - e_alpha, start_anchor_frame_np.astype(np.float32), e_alpha, 0).astype(np.uint8)
+
+ min_idx_end_fade = 1
+ max_idx_end_fade = total_frames_for_segment - num_end_anchor_duplicates - 1
+ avail_frames_end_fade = max(0, max_idx_end_fade - min_idx_end_fade + 1)
+ num_end_fade_steps = int(avail_frames_end_fade * fi_factor)
+ if num_end_fade_steps > 0:
+ actual_end_fade_start_idx = max(min_idx_end_fade, max_idx_end_fade - num_end_fade_steps + 1)
+ easing_fn_in = sm_get_easing_function(fi_curve)
+ for k_fi in range(num_end_fade_steps):
+ idx_in_guide = actual_end_fade_start_idx + k_fi
+ if idx_in_guide >= total_frames_for_segment: break
+ alpha_lin = (k_fi + 1) / float(num_end_fade_steps)
+ e_alpha = fi_low + (fi_high - fi_low) * easing_fn_in(alpha_lin)
+ base_f = frames_for_guide_list[idx_in_guide]
+ frames_for_guide_list[idx_in_guide] = cv2.addWeighted(base_f.astype(np.float32), 1.0 - e_alpha, end_anchor_frame_np.astype(np.float32), e_alpha, 0).astype(np.uint8)
+ elif fi_factor > 0 and avail_frames_end_fade > 0:
+ for k_fill in range(min_idx_end_fade, max_idx_end_fade + 1):
+ if k_fill < total_frames_for_segment: frames_for_guide_list[k_fill] = end_anchor_frame_np.copy()
+
+ elif path_to_previous_segment_video_output_for_guide: # Continued or Subsequent
+ dprint(f"GuideBuilder (Seg {segment_idx_for_logging}): Subsequent segment logic started.")
+ dprint(f"GuideBuilder: Prev video path: {path_to_previous_segment_video_output_for_guide}")
+ dprint(f"GuideBuilder: Overlap from prev setting: {frame_overlap_from_previous}")
+ dprint(f"GuideBuilder: Prev video exists: {Path(path_to_previous_segment_video_output_for_guide).exists()}")
+
+ if not Path(path_to_previous_segment_video_output_for_guide).exists():
+ raise ValueError(f"Previous video path does not exist: {path_to_previous_segment_video_output_for_guide}")
+
+ # Wait for file to be stable before reading (important for recently encoded videos)
+ dprint(f"GuideBuilder: Waiting for previous video file to stabilize...")
+ file_stable = sm_wait_for_file_stable(path_to_previous_segment_video_output_for_guide, checks=3, interval=1.0, dprint=dprint)
+ if not file_stable:
+ dprint(f"GuideBuilder: WARNING - File stability check failed, proceeding anyway")
+
+ # Get the expected frame count for the previous segment from orchestrator data
+ expected_prev_segment_frames = None
+ if segment_idx_for_logging > 0 and full_orchestrator_payload:
+ segment_frames_expanded = full_orchestrator_payload.get("segment_frames_expanded", [])
+ if segment_idx_for_logging - 1 < len(segment_frames_expanded):
+ expected_prev_segment_frames = segment_frames_expanded[segment_idx_for_logging - 1]
+ dprint(f"GuideBuilder: Previous segment expected to have {expected_prev_segment_frames} frames based on orchestrator data")
+
+ # If we have the expected frame count, use it directly
+ if expected_prev_segment_frames and expected_prev_segment_frames > 0:
+ dprint(f"GuideBuilder: Using known frame count {expected_prev_segment_frames} from orchestrator data")
+ prev_vid_total_frames = expected_prev_segment_frames
+
+ # Calculate overlap frames to extract
+ actual_overlap_to_use = min(frame_overlap_from_previous, prev_vid_total_frames)
+ start_extraction_idx = max(0, prev_vid_total_frames - actual_overlap_to_use)
+ dprint(f"GuideBuilder: Extracting {actual_overlap_to_use} frames starting from index {start_extraction_idx}")
+
+ # Extract the frames directly
+ overlap_frames_raw = extract_frames_from_video(path_to_previous_segment_video_output_for_guide, start_extraction_idx, actual_overlap_to_use, dprint_func=dprint)
+
+ # Verify we got the expected number of frames
+ if len(overlap_frames_raw) != actual_overlap_to_use:
+ dprint(f"GuideBuilder: WARNING - Expected {actual_overlap_to_use} frames but got {len(overlap_frames_raw)}. Falling back to manual extraction.")
+ # Fall back to extracting all frames
+ all_prev_frames = extract_frames_from_video(path_to_previous_segment_video_output_for_guide, dprint_func=dprint)
+ prev_vid_total_frames = len(all_prev_frames)
+ actual_overlap_to_use = min(frame_overlap_from_previous, prev_vid_total_frames)
+ overlap_frames_raw = all_prev_frames[-actual_overlap_to_use:] if actual_overlap_to_use > 0 else []
+ else:
+ # Fallback: No orchestrator data, use OpenCV or manual extraction
+ dprint(f"GuideBuilder: No orchestrator frame count data available, falling back to frame detection")
+ prev_vid_total_frames, prev_vid_fps = get_video_frame_count_and_fps(path_to_previous_segment_video_output_for_guide)
+ dprint(f"GuideBuilder: Frame count from cv2: {prev_vid_total_frames}, fps: {prev_vid_fps}")
+
+ if not prev_vid_total_frames: # Handles None or 0
+ dprint(f"GuideBuilder: Fallback triggered due to zero/None frame count. Manually reading frames.")
+ # Fallback: read all frames to determine length
+ all_prev_frames = extract_frames_from_video(path_to_previous_segment_video_output_for_guide, dprint_func=dprint)
+ prev_vid_total_frames = len(all_prev_frames)
+ dprint(f"GuideBuilder: Manual frame count from fallback: {prev_vid_total_frames}")
+ if prev_vid_total_frames == 0:
+ raise ValueError("Previous segment video appears to have zero frames – cannot build guide overlap.")
+ # Decide how many overlap frames we can reuse
+ actual_overlap_to_use = min(frame_overlap_from_previous, prev_vid_total_frames)
+ overlap_frames_raw = all_prev_frames[-actual_overlap_to_use:]
+ dprint(f"GuideBuilder: Using fallback - extracting last {actual_overlap_to_use} frames from {prev_vid_total_frames} total frames")
+ else:
+ dprint(f"GuideBuilder: Using cv2 frame count.")
+ actual_overlap_to_use = min(frame_overlap_from_previous, prev_vid_total_frames)
+ start_extraction_idx = max(0, prev_vid_total_frames - actual_overlap_to_use)
+ dprint(f"GuideBuilder: Extracting {actual_overlap_to_use} frames starting from index {start_extraction_idx}")
+ overlap_frames_raw = extract_frames_from_video(path_to_previous_segment_video_output_for_guide, start_extraction_idx, actual_overlap_to_use, dprint_func=dprint)
+
+ # Log the final overlap calculation
+ dprint(f"GuideBuilder: Calculated actual_overlap_to_use: {actual_overlap_to_use if 'actual_overlap_to_use' in locals() else 'Not calculated'}")
+ dprint(f"GuideBuilder: Extracted raw overlap frames count: {len(overlap_frames_raw) if 'overlap_frames_raw' in locals() else 'Not extracted'}")
+
+ # Check video resolution to understand if it matches our target
+ if overlap_frames_raw and len(overlap_frames_raw) > 0:
+ first_frame_shape = overlap_frames_raw[0].shape
+ prev_height, prev_width = first_frame_shape[0], first_frame_shape[1]
+ dprint(f"GuideBuilder: Previous video resolution from extracted frames: {prev_width}x{prev_height} (target: {parsed_res_wh[0]}x{parsed_res_wh[1]})")
+ if prev_width != parsed_res_wh[0] or prev_height != parsed_res_wh[1]:
+ dprint(f"GuideBuilder: Resolution mismatch detected! Previous video will be resized during guide creation.")
+
+ frames_read_for_overlap = 0
+ for k, frame_fp in enumerate(overlap_frames_raw):
+ if k >= total_frames_for_segment: break
+ original_shape = frame_fp.shape
+ if frame_fp.shape[1]!=parsed_res_wh[0] or frame_fp.shape[0]!=parsed_res_wh[1]:
+ frame_fp = cv2.resize(frame_fp, parsed_res_wh, interpolation=cv2.INTER_AREA)
+ dprint(f"GuideBuilder: Resized frame {k} from {original_shape} to {frame_fp.shape}")
+ frames_for_guide_list[k] = frame_fp.copy()
+ frames_read_for_overlap += 1
+
+ dprint(f"GuideBuilder: Frames copied into guide list: {frames_read_for_overlap}")
+
+ # Log details about what frames were actually placed in the guide
+ if frames_read_for_overlap > 0:
+ dprint(f"GuideBuilder: Guide frames 0-{frames_read_for_overlap-1} now contain frames from previous video")
+ dprint(f"GuideBuilder: Guide frames {frames_read_for_overlap}-{total_frames_for_segment-1} are still gray frames (will be modified by fade logic)")
+
+ if frames_read_for_overlap > 0:
+ if fo_factor > 0.0:
+ num_init_fade_steps = min(int(frames_read_for_overlap * fo_factor), frames_read_for_overlap)
+ easing_fn_fo_ol = sm_get_easing_function(fo_curve)
+ for k_fo_ol in range(num_init_fade_steps):
+ alpha_l = 1.0 - ((k_fo_ol + 1) / float(num_init_fade_steps))
+ eff_s = fo_low + (fo_high - fo_low) * easing_fn_fo_ol(alpha_l)
+ eff_s = np.clip(eff_s + strength_adj, 0, 1)
+ base_f=frames_for_guide_list[k_fo_ol]
+ frames_for_guide_list[k_fo_ol] = cv2.addWeighted(gray_frame_bgr.astype(np.float32),1-eff_s,base_f.astype(np.float32),eff_s,0).astype(np.uint8)
+ if desat_factor > 0:
+ g=cv2.cvtColor(frames_for_guide_list[k_fo_ol],cv2.COLOR_BGR2GRAY)
+ gb=cv2.cvtColor(g,cv2.COLOR_GRAY2BGR)
+ frames_for_guide_list[k_fo_ol]=cv2.addWeighted(frames_for_guide_list[k_fo_ol],1-desat_factor,gb,desat_factor,0)
+ if bright_adj!=0:
+ frames_for_guide_list[k_fo_ol]=sm_adjust_frame_brightness(frames_for_guide_list[k_fo_ol],bright_adj)
+ else:
+ eff_s=np.clip(fo_high+strength_adj,0,1)
+ if abs(eff_s-1.0)>1e-5 or desat_factor>0 or bright_adj!=0:
+ for k_s_ol in range(frames_read_for_overlap):
+ base_f=frames_for_guide_list[k_s_ol];frames_for_guide_list[k_s_ol]=cv2.addWeighted(gray_frame_bgr.astype(np.float32),1-eff_s,base_f.astype(np.float32),eff_s,0).astype(np.uint8)
+ if desat_factor>0: g=cv2.cvtColor(frames_for_guide_list[k_s_ol],cv2.COLOR_BGR2GRAY);gb=cv2.cvtColor(g,cv2.COLOR_GRAY2BGR);frames_for_guide_list[k_s_ol]=cv2.addWeighted(frames_for_guide_list[k_s_ol],1-desat_factor,gb,desat_factor,0)
+ if bright_adj!=0: frames_for_guide_list[k_s_ol]=sm_adjust_frame_brightness(frames_for_guide_list[k_s_ol],bright_adj)
+
+ if not single_image_journey:
+ min_idx_efs = frames_read_for_overlap; max_idx_efs = total_frames_for_segment - num_end_anchor_duplicates - 1
+ avail_f_efs = max(0, max_idx_efs - min_idx_efs + 1); num_efs_steps = int(avail_f_efs * fi_factor)
+ if num_efs_steps > 0:
+ actual_efs_start_idx = max(min_idx_efs, max_idx_efs - num_efs_steps + 1)
+ easing_fn_in_s = sm_get_easing_function(fi_curve)
+ for k_fi_s in range(num_efs_steps):
+ idx = actual_efs_start_idx+k_fi_s
+ if idx >= total_frames_for_segment: break
+ if idx < min_idx_efs: continue
+ alpha_l=(k_fi_s+1)/float(num_efs_steps);e_alpha=fi_low+(fi_high-fi_low)*easing_fn_in_s(alpha_l);e_alpha=np.clip(e_alpha,0,1)
+ base_f=frames_for_guide_list[idx];frames_for_guide_list[idx]=cv2.addWeighted(base_f.astype(np.float32),1-e_alpha,end_anchor_frame_np.astype(np.float32),e_alpha,0).astype(np.uint8)
+ elif fi_factor > 0 and avail_f_efs > 0:
+ for k_fill in range(min_idx_efs, max_idx_efs + 1):
+ if k_fill < total_frames_for_segment: frames_for_guide_list[k_fill] = end_anchor_frame_np.copy()
+
+ if not single_image_journey and total_frames_for_segment > 0 and end_anchor_frame_np is not None:
+ for k_dup in range(min(num_end_anchor_duplicates, total_frames_for_segment)):
+ idx_s = total_frames_for_segment - 1 - k_dup
+ if idx_s >= 0: frames_for_guide_list[idx_s] = end_anchor_frame_np.copy()
+ else: break
+
+ if is_first_segment_from_scratch and total_frames_for_segment > 0 and start_anchor_frame_np is not None:
+ frames_for_guide_list[0] = start_anchor_frame_np.copy()
+
+ if frames_for_guide_list:
+ guide_video_file_path = create_video_from_frames_list(frames_for_guide_list, actual_guide_video_path, fps_helpers, parsed_res_wh)
+ if guide_video_file_path and guide_video_file_path.exists():
+ return guide_video_file_path
+
+ return None
+
+ except Exception as e:
+ dprint(f"ERROR creating guide video for segment {segment_idx_for_logging}: {e}")
+ traceback.print_exc()
+ return None
+
+def _cm_enhance_saturation(image_bgr, saturation_factor=0.5):
+ """
+ Adjust saturation of an image by the given factor.
+ saturation_factor: 1.0 = no change, 0.5 = 50% reduction, 1.3 = 30% increase, etc.
+ """
+ if not _COLOR_MATCH_DEPS_AVAILABLE: return image_bgr
+ hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
+ hsv_float = hsv.astype(np.float32)
+ h, s, v = cv2.split(hsv_float)
+ s_adjusted = s * saturation_factor
+ s_adjusted = np.clip(s_adjusted, 0, 255)
+ hsv_adjusted = cv2.merge([h, s_adjusted, v])
+ hsv_adjusted_uint8 = hsv_adjusted.astype(np.uint8)
+ adjusted_bgr = cv2.cvtColor(hsv_adjusted_uint8, cv2.COLOR_HSV2BGR)
+ return adjusted_bgr
+
+def _cm_transfer_mean_std_lab(source_bgr, target_bgr):
+ if not _COLOR_MATCH_DEPS_AVAILABLE: return source_bgr
+ MIN_ALLOWED_STD_RATIO_FOR_LUMINANCE = 0.1
+ MIN_ALLOWED_STD_RATIO_FOR_COLOR = 0.4
+ source_lab = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2LAB)
+ target_lab = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2LAB)
+
+ source_lab_float = source_lab.astype(np.float32)
+ target_lab_float = target_lab.astype(np.float32)
+
+ s_l, s_a, s_b = cv2.split(source_lab_float)
+ t_l, t_a, t_b = cv2.split(target_lab_float)
+
+ channels_out = []
+ for i, (s_chan, t_chan) in enumerate(zip([s_l, s_a, s_b], [t_l, t_a, t_b])):
+ s_mean_val, s_std_val = cv2.meanStdDev(s_chan)
+ t_mean_val, t_std_val = cv2.meanStdDev(t_chan)
+ s_mean, s_std = s_mean_val[0][0], s_std_val[0][0]
+ t_mean, t_std = t_mean_val[0][0], t_std_val[0][0]
+
+ std_ratio = t_std / s_std if s_std > 1e-5 else 1.0
+
+ min_ratio = MIN_ALLOWED_STD_RATIO_FOR_LUMINANCE if i == 0 else MIN_ALLOWED_STD_RATIO_FOR_COLOR
+ effective_std_ratio = max(std_ratio, min_ratio)
+
+ if s_std > 1e-5:
+ transformed_chan = (s_chan - s_mean) * effective_std_ratio + t_mean
+ else:
+ transformed_chan = np.full_like(s_chan, t_mean)
+
+ channels_out.append(transformed_chan)
+
+ result_lab_float = cv2.merge(channels_out)
+ result_lab_clipped = np.clip(result_lab_float, 0, 255)
+ result_lab_uint8 = result_lab_clipped.astype(np.uint8)
+ result_bgr = cv2.cvtColor(result_lab_uint8, cv2.COLOR_LAB2BGR)
+ return result_bgr
+
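+# Per channel, the transfer above is a mean/variance match in Lab space
+# (Reinhard-style colour transfer):
+#   out = (src - mean_src) * max(std_ref / std_src, min_ratio) + mean_ref
+# where min_ratio (0.1 for luminance, 0.4 for the colour channels) keeps a
+# low-variance reference frame from flattening the result.
+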
+def apply_color_matching_to_video(video_path: str, start_ref_path: str, end_ref_path: str, output_path: str, dprint):
+ if not all([_COLOR_MATCH_DEPS_AVAILABLE, Path(video_path).exists(), Path(start_ref_path).exists(), Path(end_ref_path).exists()]):
+ dprint(f"Color Matching: Skipping due to missing deps or files. Deps:{_COLOR_MATCH_DEPS_AVAILABLE}, Video:{Path(video_path).exists()}, Start:{Path(start_ref_path).exists()}, End:{Path(end_ref_path).exists()}")
+ return None
+
+ frames = extract_frames_from_video(video_path)
+ frame_count, fps = get_video_frame_count_and_fps(video_path)
+ if not frames or not frame_count or not fps:
+ dprint("Color Matching: Frame extraction or metadata retrieval failed.")
+ return None
+
+ # Get resolution from the first frame
+ h, w, _ = frames[0].shape
+ resolution = (w, h)
+
+ start_ref_bgr = cv2.imread(start_ref_path)
+ end_ref_bgr = cv2.imread(end_ref_path)
+ start_ref_resized = cv2.resize(start_ref_bgr, resolution)
+ end_ref_resized = cv2.resize(end_ref_bgr, resolution)
+
+ total_frames = len(frames)
+ accumulated_frames = []
+
+ for i, frame_bgr in enumerate(frames):
+ frame_bgr_desaturated = _cm_enhance_saturation(frame_bgr, saturation_factor=0.5)
+
+ corrected_start_bgr = _cm_transfer_mean_std_lab(frame_bgr_desaturated, start_ref_resized)
+ corrected_end_bgr = _cm_transfer_mean_std_lab(frame_bgr_desaturated, end_ref_resized)
+
+ t = i / (total_frames - 1) if total_frames > 1 else 1.0
+ w_original = (0.5 * t) if t < 0.5 else (0.5 - 0.5 * t)
+ w_correct = 1.0 - w_original
+ w_start = (1.0 - t) * w_correct
+ w_end = t * w_correct
+
+ blend_float = (w_start * corrected_start_bgr.astype(np.float32) +
+ w_end * corrected_end_bgr.astype(np.float32) +
+ w_original * frame_bgr.astype(np.float32))
+
+ blended_frame_bgr = np.clip(blend_float, 0, 255).astype(np.uint8)
+ accumulated_frames.append(blended_frame_bgr)
+
+ if accumulated_frames:
+ created_video_path = create_video_from_frames_list(accumulated_frames, output_path, fps, resolution)
+ dprint(f"Color Matching: Successfully created color matched video at {created_video_path}")
+ return created_video_path
+
+ dprint("Color Matching: Failed to produce any frames.")
+ return None
+
+def extract_last_frame_as_image(video_path: str | Path, output_dir: Path, task_id_for_log: str) -> str | None:
+ """
+ Extracts the last frame of a video and saves it as a PNG image.
+ """
+ if not _COLOR_MATCH_DEPS_AVAILABLE:
+ dprint(f"Task {task_id_for_log} extract_last_frame_as_image: Skipping due to missing CV2/Numpy dependencies.")
+ return None
+ try:
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ dprint(f"[ERROR Task {task_id_for_log}] extract_last_frame_as_image: Could not open video {video_path}")
+ return None
+
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ if frame_count <= 0:
+ cap.release()
+ dprint(f"Task {task_id_for_log} extract_last_frame_as_image: Video has 0 frames {video_path}")
+ return None
+
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
+ ret, frame = cap.read()
+ cap.release()
+
+ if ret:
+ output_path = output_dir / f"last_frame_ref_{Path(video_path).stem}.png"
+ cv2.imwrite(str(output_path), frame)
+ return str(output_path.resolve())
+ dprint(f"Task {task_id_for_log} extract_last_frame_as_image: Failed to read last frame from {video_path}")
+ return None
+ except Exception as e:
+ dprint(f"[ERROR Task {task_id_for_log}] extract_last_frame_as_image: Exception extracting frame from {video_path}: {e}")
+ traceback.print_exc()
+ return None
+
+# --- New utility: Overlay input images above video (Start & End side-by-side) ---
+
+def overlay_start_end_images_above_video(
+ start_image_path: str | Path,
+ end_image_path: str | Path,
+ input_video_path: str | Path,
+ output_video_path: str | Path,
+ *,
+ dprint=print,
+) -> bool:
+ """Creates a composite video that shows *start_image* (left) and *end_image* (right)
+ on a row above the *input_video*.
+
+ Layout:
+ START | END (static images, full video duration)
+ ----------------------
+ VIDEO (original video frames)
+
+    The resulting video keeps the original width of *input_video*. Each image is
+    scaled to half that width (the FFmpeg path also forces it to the video height)
+    so the banner lines up with the video below; in that case the final output is
+    ``video_width`` wide and ``video_height * 2`` tall.
+
+ Args:
+ start_image_path: Path to the starting image.
+ end_image_path: Path to the ending image.
+ input_video_path: Source video that was generated.
+ output_video_path: Desired path for the composite video.
+ dprint: Debug print function.
+
+ Returns:
+ True if the composite video was created successfully, else False.
+ """
+ try:
+ start_image_path = Path(start_image_path)
+ end_image_path = Path(end_image_path)
+ input_video_path = Path(input_video_path)
+ output_video_path = Path(output_video_path)
+
+ if not (start_image_path.exists() and end_image_path.exists() and input_video_path.exists()):
+ dprint(
+ f"overlay_start_end_images_above_video: One or more input paths are missing.\n"
+ f" start_image_path = {start_image_path}\n"
+ f" end_image_path = {end_image_path}\n"
+ f" input_video_path = {input_video_path}"
+ )
+ return False
+
+ # ---------------------------------------------------------
+ # Preferred implementation: MoviePy (simpler, robust)
+ # ---------------------------------------------------------
+        use_ffmpeg_fallback = not _MOVIEPY_AVAILABLE
+        if _MOVIEPY_AVAILABLE:
+ try:
+ video_clip = mpe.VideoFileClip(str(input_video_path))
+
+ half_width_px = int(video_clip.w / 2)
+
+ img1_clip = mpe.ImageClip(str(start_image_path)).resize(width=half_width_px).set_duration(video_clip.duration)
+ img2_clip = mpe.ImageClip(str(end_image_path)).resize(width=half_width_px).set_duration(video_clip.duration)
+
+ # top row (images side-by-side)
+ top_row = mpe.clips_array([[img1_clip, img2_clip]])
+
+ # Build composite video
+ final_h = top_row.h + video_clip.h
+ composite = mpe.CompositeVideoClip([
+ top_row.set_position((0, 0)),
+ video_clip.set_position((0, top_row.h))
+ ], size=(video_clip.w, final_h))
+
+                # Write video (always to an .mp4 path, matching the check below)
+                final_mp4_path = output_video_path.with_suffix('.mp4')
+                composite.write_videofile(
+                    str(final_mp4_path),
+                    codec="libx264",
+                    audio=False,
+                    fps=video_clip.fps or 16,
+                    preset="veryfast",
+                )
+
+                video_clip.close(); img1_clip.close(); img2_clip.close(); composite.close()
+
+                if final_mp4_path.exists() and final_mp4_path.stat().st_size > 0:
+                    return True
+                else:
+                    dprint("overlay_start_end_images_above_video: MoviePy output missing after write.")
+                    use_ffmpeg_fallback = True
+            except Exception as e_mov:
+                dprint(f"overlay_start_end_images_above_video: MoviePy path failed – {e_mov}. Falling back to ffmpeg.")
+                use_ffmpeg_fallback = True
+
+ # ---------------------------------------------------------
+ # Fallback: FFmpeg filter_complex (no MoviePy or MoviePy failed)
+ # ---------------------------------------------------------
+        if use_ffmpeg_fallback:
+ try:
+ # ---------------------------------------------------------
+ # Determine the resolution of the **input video**
+ # ---------------------------------------------------------
+ cap = cv2.VideoCapture(str(input_video_path))
+ if not cap.isOpened():
+ dprint(f"overlay_start_end_images_above_video: Could not open video {input_video_path}")
+ return False
+ video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ cap.release()
+
+ if video_width == 0 or video_height == 0:
+ dprint(
+ f"overlay_start_end_images_above_video: Failed to read resolution from {input_video_path}"
+ )
+ return False
+
+ half_width = video_width // 2 # Integer division for even-pixel alignment
+
+ # Ensure output directory exists
+ output_video_path = output_video_path.with_suffix('.mp4')
+ output_video_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # ---------------------------------------------------------
+ # Build & run ffmpeg command
+ # ---------------------------------------------------------
+ # 1) Scale both images to half width / full height
+ # 2) hstack -> side-by-side static banner (labelled [top])
+ # 3) vstack banner with original video (labelled [vid])
+ # NOTE: We explicitly scale the *video* as well into label [vid]
+ # to guarantee width/height match. This is mostly defensive –
+ # if the video is already the desired size the scale is a NOP.
+
+ filter_complex = (
+ f"[1:v]scale={half_width}:{video_height}[left];" # scale start img
+ f"[2:v]scale={half_width}:{video_height}[right];" # scale end img
+ f"[left][right]hstack=inputs=2[top];" # combine images
+ f"[0:v]scale={video_width}:{video_height}[vid];" # ensure video size
+ f"[top][vid]vstack=inputs=2[output]" # stack banner + video
+ )
+
+ # Determine FPS from the input video for consistent output
+ fps = 0.0
+ try:
+ cap2 = cv2.VideoCapture(str(input_video_path))
+ if cap2.isOpened():
+ fps = cap2.get(cv2.CAP_PROP_FPS)
+ cap2.release()
+ except Exception:
+ fps = 0.0
+ if fps is None or fps <= 0.1:
+ fps = 16 # sensible default
+
+ ffmpeg_cmd = [
+ "ffmpeg", "-y", # overwrite output
+ "-loglevel", "error",
+ "-i", str(input_video_path),
+ "-loop", "1", "-i", str(start_image_path),
+ "-loop", "1", "-i", str(end_image_path),
+ "-filter_complex", filter_complex,
+ "-map", "[output]",
+ "-r", str(int(round(fps))), # set output fps
+ "-shortest", # stop when primary video stream ends
+ "-c:v", "libx264", "-pix_fmt", "yuv420p",
+ "-movflags", "+faststart",
+ str(output_video_path.resolve()),
+ ]
+
+ dprint(f"overlay_start_end_images_above_video: Running ffmpeg to create composite video.\nCommand: {' '.join(ffmpeg_cmd)}")
+
+ proc = subprocess.run(ffmpeg_cmd, capture_output=True)
+ if proc.returncode != 0:
+ dprint(
+ f"overlay_start_end_images_above_video: ffmpeg failed (returncode={proc.returncode}).\n"
+ f"stderr: {proc.stderr.decode(errors='ignore')[:500]}"
+ )
+ # Clean up partially written file if any
+ if output_video_path.exists():
+ try:
+ output_video_path.unlink()
+ except Exception:
+ pass
+ return False
+
+ if not output_video_path.exists() or output_video_path.stat().st_size == 0:
+ dprint(
+ f"overlay_start_end_images_above_video: Output video not created or empty at {output_video_path}"
+ )
+ return False
+
+ return True
+
+            except Exception as e_ffmpeg:
+                dprint(f"overlay_start_end_images_above_video: ffmpeg path failed – {e_ffmpeg}.")
+                return False
+
+ except Exception as e_ov:
+ dprint(f"overlay_start_end_images_above_video: Exception – {e_ov}")
+ traceback.print_exc()
+ try:
+ if output_video_path and output_video_path.exists():
+ output_video_path.unlink()
+ except Exception:
+ pass
+ return False
\ No newline at end of file
diff --git a/source/wgp_utils.py b/source/wgp_utils.py
new file mode 100644
index 000000000..66c1e19c7
--- /dev/null
+++ b/source/wgp_utils.py
@@ -0,0 +1,319 @@
+"""Utility helpers around Wan2GP.wgp API used by headless tasks."""
+
+from __future__ import annotations
+
+import json
+import time
+from pathlib import Path
+from typing import Any, Tuple, Dict, Optional, Union, Callable
+from PIL import Image
+
+
+def _ensure_lora_in_lists(lora_name: str, multiplier: str, activated_loras: list, loras_multipliers: Union[list, str]) -> Tuple[list, Union[list, str]]:
+ """
+ Helper function to ensure a LoRA is present in activated_loras with the correct multiplier.
+ Returns updated (activated_loras, loras_multipliers) tuple.
+ """
+ if lora_name not in activated_loras:
+ activated_loras.insert(0, lora_name)
+ if isinstance(loras_multipliers, list):
+ loras_multipliers.insert(0, multiplier)
+ elif isinstance(loras_multipliers, str):
+ mult_list = [m.strip() for m in loras_multipliers.split(",") if m.strip()] if loras_multipliers else []
+ mult_list.insert(0, multiplier)
+ loras_multipliers = ",".join(mult_list)
+ else:
+ loras_multipliers = [multiplier]
+ return activated_loras, loras_multipliers
+
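+# Example of the helper's behaviour (illustrative values): with
+# activated_loras=["styleA.safetensors"] and loras_multipliers="0.8",
+# _ensure_lora_in_lists("causvid.safetensors", "1.0", ...) prepends the new LoRA
+# and returns (["causvid.safetensors", "styleA.safetensors"], "1.0,0.8");
+# if the LoRA is already listed, both values are returned unchanged.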
+
+def _set_param_if_different(params: dict, key: str, target_value: Union[int, float], task_id: str, lora_name: str, dprint: Callable):
+ """
+ Helper function to set a parameter value if it differs from the target, with logging.
+ """
+ current_value = params.get(key)
+ if isinstance(target_value, int):
+ try:
+ current_int = int(current_value) if current_value is not None else 0
+ except (TypeError, ValueError):
+ current_int = 0
+ if current_int != target_value:
+ dprint(f"{task_id}: {lora_name} active – overriding {key} {current_int} → {target_value}")
+ params[key] = target_value
+    elif isinstance(target_value, float):
+        try:
+            current_float = float(current_value) if current_value is not None else 0.0
+        except (TypeError, ValueError):
+            current_float = 0.0
+        if current_float != target_value:
+            dprint(f"{task_id}: {lora_name} active – setting {key} → {target_value}")
+            params[key] = target_value
+
+
+def _normalize_loras_multipliers_format(loras_multipliers: Union[list, str]) -> str:
+ """
+ Helper function to normalize loras_multipliers to string format expected by WGP.
+ """
+ if isinstance(loras_multipliers, list):
+ return ",".join(loras_multipliers) if loras_multipliers else ""
+ else:
+ return loras_multipliers if loras_multipliers else ""
+
+
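+# Typical call (illustrative sketch; the prompt and model filename below are
+# hypothetical, and headless task handlers normally fill these in from the task payload):
+#
+#   ok, result = generate_single_video(
+#       wgp_mod, task_id="demo-123",
+#       prompt="a slow pan across a foggy valley",
+#       resolution="832x480", video_length=81, seed=12345,
+#       model_filename="<path to the active Wan model checkpoint>",
+#   )
+#   # On success `result` is the generated .mp4 path; otherwise it is an error message.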
+def generate_single_video(
+ wgp_mod,
+ task_id: str,
+ prompt: str,
+ negative_prompt: str = "",
+ resolution: str = "512x512",
+ video_length: int = 81,
+ seed: int = 12345,
+ num_inference_steps: int = 30,
+ guidance_scale: float = 5.0,
+ flow_shift: float = 3.0,
+ model_filename: str = None,
+ video_guide: str = None,
+ video_mask: str = None,
+ image_refs: list = None,
+ use_causvid_lora: bool = False,
+ use_lighti2x_lora: bool = False,
+ apply_reward_lora: bool = False,
+ additional_loras: dict = None,
+ video_prompt_type: str = "T",
+ dprint = None,
+ **kwargs
+) -> tuple[bool, str]:
+ """
+ Centralized wrapper for WGP video generation with comprehensive debugging.
+ Returns (success_bool, output_path_or_error_message)
+ """
+ if dprint is None:
+ dprint = lambda x: print(f"[DEBUG] {x}")
+
+ print(f"[WGP_GENERATION_DEBUG] Starting generation for task {task_id}")
+ print(f"[WGP_GENERATION_DEBUG] Parameters:")
+ print(f"[WGP_GENERATION_DEBUG] prompt: {prompt}")
+ print(f"[WGP_GENERATION_DEBUG] resolution: {resolution}")
+ print(f"[WGP_GENERATION_DEBUG] video_length: {video_length}")
+ print(f"[WGP_GENERATION_DEBUG] seed: {seed}")
+ print(f"[WGP_GENERATION_DEBUG] num_inference_steps: {num_inference_steps}")
+ print(f"[WGP_GENERATION_DEBUG] guidance_scale: {guidance_scale}")
+ print(f"[WGP_GENERATION_DEBUG] flow_shift: {flow_shift}")
+ print(f"[WGP_GENERATION_DEBUG] use_causvid_lora: {use_causvid_lora}")
+ print(f"[WGP_GENERATION_DEBUG] use_lighti2x_lora: {use_lighti2x_lora}")
+ print(f"[WGP_GENERATION_DEBUG] video_guide: {video_guide}")
+ print(f"[WGP_GENERATION_DEBUG] video_mask: {video_mask}")
+ print(f"[WGP_GENERATION_DEBUG] video_prompt_type: {video_prompt_type}")
+
+ try:
+ # Analyze input guide video if provided
+ if video_guide and Path(video_guide).exists():
+ from .common_utils import get_video_frame_count_and_fps
+ guide_frames, guide_fps = get_video_frame_count_and_fps(video_guide)
+ print(f"[WGP_GENERATION_DEBUG] Guide video analysis:")
+ print(f"[WGP_GENERATION_DEBUG] Path: {video_guide}")
+ print(f"[WGP_GENERATION_DEBUG] Frames: {guide_frames}")
+ print(f"[WGP_GENERATION_DEBUG] FPS: {guide_fps}")
+ if guide_frames != video_length:
+ print(f"[WGP_GENERATION_DEBUG] ⚠️ GUIDE/TARGET MISMATCH! Guide has {guide_frames} frames, target is {video_length}")
+
+ # Analyze input mask video if provided
+ if video_mask and Path(video_mask).exists():
+ from .common_utils import get_video_frame_count_and_fps
+ mask_frames, mask_fps = get_video_frame_count_and_fps(video_mask)
+ print(f"[WGP_GENERATION_DEBUG] Mask video analysis:")
+ print(f"[WGP_GENERATION_DEBUG] Path: {video_mask}")
+ print(f"[WGP_GENERATION_DEBUG] Frames: {mask_frames}")
+ print(f"[WGP_GENERATION_DEBUG] FPS: {mask_fps}")
+ if mask_frames != video_length:
+ print(f"[WGP_GENERATION_DEBUG] ⚠️ MASK/TARGET MISMATCH! Mask has {mask_frames} frames, target is {video_length}")
+
+ # Build task state for WGP
+ from .common_utils import build_task_state
+
+ task_params_dict = {
+ "task_id": task_id,
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "resolution": resolution,
+ "video_length": video_length,
+ "frames": video_length, # Alternative key
+ "seed": seed,
+ "num_inference_steps": num_inference_steps,
+ "guidance_scale": guidance_scale,
+ "flow_shift": flow_shift,
+ "video_guide_path": video_guide,
+ "video_mask": video_mask,
+ "image_refs_paths": image_refs,
+ "video_prompt_type": video_prompt_type,
+ "use_causvid_lora": use_causvid_lora,
+ "use_lighti2x_lora": use_lighti2x_lora,
+ "apply_reward_lora": apply_reward_lora,
+ "processed_additional_loras": additional_loras or {},
+ **kwargs
+ }
+
+ print(f"[WGP_GENERATION_DEBUG] Built task_params_dict with {len(task_params_dict)} parameters")
+
+ # Get LoRA directory and setup
+ lora_dir_for_active_model = wgp_mod.get_lora_dir(model_filename)
+ all_loras_for_active_model, _, _, _, _, _, _ = wgp_mod.setup_loras(
+ model_filename, None, lora_dir_for_active_model, "", None
+ )
+
+ print(f"[WGP_GENERATION_DEBUG] LoRA setup complete. Available LoRAs: {len(all_loras_for_active_model) if all_loras_for_active_model else 0}")
+
+ # Build state and UI params
+ state, ui_params = build_task_state(
+ wgp_mod,
+ model_filename,
+ task_params_dict,
+ all_loras_for_active_model,
+ None, # image_download_dir
+ apply_reward_lora=apply_reward_lora
+ )
+
+ print(f"[WGP_GENERATION_DEBUG] State and UI params built")
+ print(f"[WGP_GENERATION_DEBUG] Final ui_params video_length: {ui_params.get('video_length', 'NOT_SET')}")
+ print(f"[WGP_GENERATION_DEBUG] Final ui_params frames: {ui_params.get('frames', 'NOT_SET')}")
+
+ # Create temporary output directory
+ import tempfile
+ temp_output_dir = tempfile.mkdtemp(prefix=f"wgp_single_{task_id}_")
+ print(f"[WGP_GENERATION_DEBUG] Using temporary output directory: {temp_output_dir}")
+
+ # Save original path and set temporary
+ original_wgp_save_path = wgp_mod.save_path
+ wgp_mod.save_path = str(temp_output_dir)
+
+ try:
+ # Create task and send_cmd objects
+ gen_task_placeholder = {
+ "id": 1,
+ "prompt": ui_params.get("prompt"),
+ "params": {
+ "model_filename_from_gui_state": model_filename,
+ "model": kwargs.get("model", "t2v")
+ }
+ }
+
+ def send_cmd_debug(cmd, data=None):
+ if cmd == "progress":
+ if isinstance(data, list) and len(data) >= 2:
+ prog, txt = data[0], data[1]
+ if isinstance(prog, tuple) and len(prog) == 2:
+ step, total = prog
+ print(f"[WGP_PROGRESS] {step}/{total} – {txt}")
+ else:
+ print(f"[WGP_PROGRESS] {txt}")
+ elif cmd == "status":
+ print(f"[WGP_STATUS] {data}")
+ elif cmd == "info":
+ print(f"[WGP_INFO] {data}")
+ elif cmd == "error":
+ print(f"[WGP_ERROR] {data}")
+ raise RuntimeError(f"WGP error for {task_id}: {data}")
+ elif cmd == "output":
+ print(f"[WGP_OUTPUT] Video generation completed")
+
+ print(f"[WGP_GENERATION_DEBUG] Calling wgp_mod.generate_video...")
+ print(f"[WGP_GENERATION_DEBUG] Final parameters being passed:")
+ print(f"[WGP_GENERATION_DEBUG] video_length: {ui_params.get('video_length')}")
+ print(f"[WGP_GENERATION_DEBUG] resolution: {ui_params.get('resolution')}")
+ print(f"[WGP_GENERATION_DEBUG] seed: {ui_params.get('seed')}")
+ print(f"[WGP_GENERATION_DEBUG] num_inference_steps: {ui_params.get('num_inference_steps')}")
+
+ # Call the actual WGP generation
+ wgp_mod.generate_video(
+ task=gen_task_placeholder,
+ send_cmd=send_cmd_debug,
+ prompt=ui_params["prompt"],
+ negative_prompt=ui_params.get("negative_prompt", ""),
+ resolution=ui_params["resolution"],
+ video_length=ui_params.get("video_length", video_length),
+ seed=ui_params["seed"],
+ num_inference_steps=ui_params.get("num_inference_steps", num_inference_steps),
+ guidance_scale=ui_params.get("guidance_scale", guidance_scale),
+ flow_shift=ui_params.get("flow_shift", flow_shift),
+ video_guide=ui_params.get("video_guide"),
+ video_mask=ui_params.get("video_mask"),
+ image_refs=ui_params.get("image_refs"),
+ video_prompt_type=ui_params.get("video_prompt_type", video_prompt_type),
+ activated_loras=ui_params.get("activated_loras", []),
+ loras_multipliers=ui_params.get("loras_multipliers", ""),
+ state=state,
+ model_filename=model_filename,
+ # Add other parameters as needed
+ audio_guidance_scale=ui_params.get("audio_guidance_scale", 5.0),
+ embedded_guidance_scale=ui_params.get("embedded_guidance_scale", 6.0),
+ repeat_generation=ui_params.get("repeat_generation", 1),
+ multi_images_gen_type=ui_params.get("multi_images_gen_type", 0),
+ tea_cache_setting=ui_params.get("tea_cache_setting", 0.0),
+ tea_cache_start_step_perc=ui_params.get("tea_cache_start_step_perc", 0),
+ image_prompt_type=ui_params.get("image_prompt_type", "T"),
+ image_start=[wgp_mod.convert_image(img) for img in ui_params.get("image_start", [])],
+ image_end=[wgp_mod.convert_image(img) for img in ui_params.get("image_end", [])],
+ model_mode=ui_params.get("model_mode", 0),
+ video_source=ui_params.get("video_source"),
+ keep_frames_video_source=ui_params.get("keep_frames_video_source", ""),
+ keep_frames_video_guide=ui_params.get("keep_frames_video_guide", ""),
+ audio_guide=ui_params.get("audio_guide"),
+ sliding_window_size=ui_params.get("sliding_window_size", 81),
+ sliding_window_overlap=ui_params.get("sliding_window_overlap", 5),
+ sliding_window_overlap_noise=ui_params.get("sliding_window_overlap_noise", 20),
+ sliding_window_discard_last_frames=ui_params.get("sliding_window_discard_last_frames", 0),
+ remove_background_images_ref=ui_params.get("remove_background_images_ref", False),
+ temporal_upsampling=ui_params.get("temporal_upsampling", ""),
+ spatial_upsampling=ui_params.get("spatial_upsampling", ""),
+ RIFLEx_setting=ui_params.get("RIFLEx_setting", 0),
+ slg_switch=ui_params.get("slg_switch", 0),
+ slg_layers=ui_params.get("slg_layers", [9]),
+ slg_start_perc=ui_params.get("slg_start_perc", 10),
+ slg_end_perc=ui_params.get("slg_end_perc", 90),
+ cfg_star_switch=ui_params.get("cfg_star_switch", 0),
+ cfg_zero_step=ui_params.get("cfg_zero_step", -1),
+ prompt_enhancer=ui_params.get("prompt_enhancer", "")
+ )
+
+ print(f"[WGP_GENERATION_DEBUG] WGP generation call completed")
+
+ # Find generated video files
+ generated_video_files = sorted([
+ item for item in Path(temp_output_dir).iterdir()
+ if item.is_file() and item.suffix.lower() == ".mp4"
+ ])
+
+ print(f"[WGP_GENERATION_DEBUG] Found {len(generated_video_files)} video files in output directory")
+
+ if not generated_video_files:
+ print(f"[WGP_GENERATION_DEBUG] ERROR: No .mp4 files found in {temp_output_dir}")
+ return False, f"No video files generated in {temp_output_dir}"
+
+ # Analyze each generated file
+ from .common_utils import get_video_frame_count_and_fps
+ for i, video_file in enumerate(generated_video_files):
+ try:
+ frames, fps = get_video_frame_count_and_fps(str(video_file))
+ file_size = video_file.stat().st_size
+ print(f"[WGP_GENERATION_DEBUG] Generated file {i}: {video_file.name}")
+ print(f"[WGP_GENERATION_DEBUG] Frames: {frames}")
+ print(f"[WGP_GENERATION_DEBUG] FPS: {fps}")
+ print(f"[WGP_GENERATION_DEBUG] Size: {file_size / (1024*1024):.2f} MB")
+ print(f"[WGP_GENERATION_DEBUG] Expected frames: {video_length}")
+ if frames != video_length:
+ print(f"[WGP_GENERATION_DEBUG] ⚠️ FRAME COUNT MISMATCH! Expected {video_length}, got {frames}")
+ except Exception as e:
+ print(f"[WGP_GENERATION_DEBUG] ERROR analyzing {video_file}: {e}")
+
+ # Return the first (or only) generated file
+ final_output = str(generated_video_files[0].resolve())
+ print(f"[WGP_GENERATION_DEBUG] Returning output: {final_output}")
+
+ return True, final_output
+
+ finally:
+ # Restore original save path
+ wgp_mod.save_path = original_wgp_save_path
+
+ except Exception as e:
+ print(f"[WGP_GENERATION_DEBUG] ERROR during generation: {e}")
+ import traceback
+ traceback.print_exc()
+ return False, f"Generation failed: {str(e)}"
\ No newline at end of file
diff --git a/steerable_motion.py b/steerable_motion.py
deleted file mode 100644
index 6f42d8e89..000000000
--- a/steerable_motion.py
+++ /dev/null
@@ -1,310 +0,0 @@
-"""Steerable Motion orchestrator CLI script.
-
-This entry-point parses command-line arguments for the two high-level tasks
-(`travel_between_images` and `different_pose`), initialises logging/debug
-behaviour, ensures the local SQLite `tasks` database exists, and then
-delegates the heavy-lifting to modular handlers living in
-`sm_functions.travel_between_images` and `sm_functions.different_pose`.
-
-The script itself therefore coordinates the overall workflow, keeps global
-state (e.g. DEBUG_MODE), and performs lightweight orchestration rather than
-image/video processing.
-"""
-
-import argparse
-import sqlite3
-import json
-import time
-import uuid
-import os
-from pathlib import Path
-import shutil
-import subprocess
-import tempfile
-import sys
-import traceback
-import shlex # Add shlex import
-from dotenv import load_dotenv
-
-# --- Import from our new sm_functions package ---
-from sm_functions import (
- run_travel_between_images_task,
- run_different_pose_task,
- # Common utilities that steerable_motion.py might directly use (e.g. for init)
- DEBUG_MODE as SM_DEBUG_MODE, # Alias to avoid conflict if main script defines its own DEBUG_MODE
- DEFAULT_DB_TABLE_NAME,
- dprint,
- parse_resolution
-)
-# Expose DEBUG_MODE from common_utils to the global scope of this script
-# This allows dprint within this file (if any) to work as expected.
-# The actual value will be set after parsing args.
-global DEBUG_MODE
-DEBUG_MODE = SM_DEBUG_MODE # Initialize with the value from common_utils
-
-# --- Constants for DB interaction and defaults (specific to steerable_motion.py argparser) ---
-# These were NOT moved to common_utils as they are tied to CLI parsing here.
-DEFAULT_MODEL_NAME = "vace_14B"
-DEFAULT_SEGMENT_FRAMES = 81
-DEFAULT_FPS_HELPERS = 25
-DEFAULT_SEED = 12345
-
-# ----------------------------------------------------
-# Helper: Ensure the SQLite DB and tasks table exist
-# ----------------------------------------------------
-
-def _ensure_db_initialized(db_path_str: str, table_name: str = DEFAULT_DB_TABLE_NAME):
- """Creates the tasks table (and helpful index) in the SQLite DB if it doesn't exist."""
- conn = sqlite3.connect(db_path_str)
- cursor = conn.cursor()
- cursor.execute(
- f"""
- CREATE TABLE IF NOT EXISTS {table_name} (
- task_id TEXT PRIMARY KEY,
- params TEXT NOT NULL,
- task_type TEXT NOT NULL,
- status TEXT NOT NULL DEFAULT 'Queued',
- output_location TEXT NULL,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
- )
- """
- )
- cursor.execute(f"CREATE INDEX IF NOT EXISTS idx_status_created_at ON {table_name} (status, created_at)")
- conn.commit()
- conn.close()
-
-# ----------------------------------------------------
-
-def main():
- # --- Load .env file for potential SQLITE_DB_PATH_ENV ---
- # This is a simple way to allow .env to influence db_path before arg parsing sets defaults.
- # A more robust solution might integrate dotenv earlier or pass it explicitly.
- load_dotenv()
- env_sqlite_db_path = os.getenv("SQLITE_DB_PATH_ENV")
- # --- End .env load ---
-
- common_parser = argparse.ArgumentParser(add_help=False)
- common_parser.add_argument("--resolution", type=str, required=True, help="Output resolution, e.g., '960x544'.")
- common_parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME, help="Model name for headless.py tasks.")
- common_parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Base seed for tasks, incremented per segment/step.")
- common_parser.add_argument("--db_path", type=str, default="tasks.db", help="Path to the SQLite database for headless.py.")
- common_parser.add_argument("--output_dir", type=str, default="./steerable_motion_output", help="Base directory for all outputs.")
- common_parser.add_argument("--fps_helpers", type=int, default=DEFAULT_FPS_HELPERS, help="FPS for generated mask/guide videos.")
- common_parser.add_argument("--output_video_frames", type=int, default=30, help="Number of frames for the final generated video output (e.g., in different_pose). Default 30.")
- common_parser.add_argument("--poll_interval", type=int, default=15, help="Polling interval (seconds) for task status.")
- common_parser.add_argument("--poll_timeout", type=int, default=30 * 60, help="Timeout (seconds) for polling a single task.")
- common_parser.add_argument("--skip_cleanup", action="store_true", help="Skip cleanup of intermediate segment/task directories.")
- common_parser.add_argument("--debug", action="store_true", help="Enable verbose debug logging and prevent cleanup of intermediate files.")
- common_parser.add_argument("--use_causvid_lora", action="store_true", help="Enable CausVid LoRA for video generation tasks.")
- common_parser.add_argument("--upscale_model_name", type=str, default="ltxv_13B", help="Model name for headless.py to use for upscaling tasks (default: ltxv_13B).")
- common_parser.add_argument("--last_frame_duplication", type=int, default=0, help="Number of additional times the last anchor frame is repeated at the end of the guide video. 0 means the last anchor is only the very last frame; 4 means the last anchor frame becomes the last 5 frames of the guide.")
- common_parser.add_argument("--final_image_strength", type=float, default=1.0, help="Strength of the final anchor image when used as a VACE reference (0.0 to 1.0). 1.0 is full strength (opaque), 0.0 is fully transparent. Using this implicitly makes the end anchor a VACE ref. Values outside [0,1] will be clamped.")
- common_parser.add_argument("--initial_image_strength", type=float, default=0.0, help="Strength of the initial anchor image when used as a VACE reference (0.0 to 1.0). 1.0 is full strength (opaque), 0.0 is fully transparent. Similar to --final_image_strength but for the start anchor. Values outside [0,1] will be clamped.")
- common_parser.add_argument(
- "--fade_in_duration",
- type=str,
- default='{"low_point": 0.0, "high_point": 1.0, "curve_type": "ease_in_out", "duration_factor": 0.0}',
- help="JSON string for fade-in settings of the end anchor in the guide video. Example: '{\"low_point\": 0.0, \"high_point\": 1.0, \"curve_type\": \"ease_in_out\", \"duration_factor\": 0.5}'. 'duration_factor' is proportion of available frames (0.0-1.0)."
- )
- common_parser.add_argument(
- "--fade_out_duration",
- type=str,
- default='{"low_point": 0.0, "high_point": 1.0, "curve_type": "ease_in_out", "duration_factor": 0.0}',
- help="JSON string for fade-out settings of the start anchor in the guide video. Example: '{\"low_point\": 0.0, \"high_point\": 1.0, \"curve_type\": \"ease_in_out\", \"duration_factor\": 0.5}'. 'duration_factor' is proportion of available frames (0.0-1.0)."
- )
- common_parser.add_argument(
- "--subsequent_starting_strength_adjustment",
- type=float,
- default=0.0,
- help="Float value to adjust the starting strength of subsequent/overlapped segment's initial frames. Applied after fade-out calculations. E.g., -0.2 to reduce strength, 0.1 to increase."
- )
- common_parser.add_argument(
- "--desaturate_subsequent_starting_frames",
- type=float,
- default=0.0,
- help="Float value (0.0 to 1.0) to desaturate the initial frames of subsequent/overlapped segments. 0.0 is no desaturation, 1.0 is fully grayscale. Applied after strength adjustments."
- )
- common_parser.add_argument(
- "--adjust_brightness_subsequent_starting_frames",
- type=float,
- default=0.0,
- help="Float value to adjust brightness of initial frames of subsequent/overlapped segments. Positive values make it darker (e.g., 0.1 for 10% darker), negative values make it brighter (e.g., -0.1 for 10% brighter). 0.0 means no change. Applied after strength and desaturation."
- )
- common_parser.add_argument("--execution_engine", type=str, default="wgp", choices=["wgp", "comfyui"], help="The execution engine to use for generation tasks (wgp or comfyui).")
- common_parser.add_argument(
- "--cfg_star_switch",
- type=int,
- default=0, # Default to off
- choices=[0, 1],
- help="Enable CFG* (Zero-Star) technique (0 for Off, 1 for On). Modulates CFG application."
- )
- common_parser.add_argument(
- "--cfg_zero_step",
- type=int,
- default=-1, # Default to -1 (inactive or wgp.py's internal default)
- help="Number of initial steps (0-indexed) where conditional guidance is zeroed if CFG* is active. -1 might mean inactive or use wgp.py default."
- )
- common_parser.add_argument(
- "--params_json_str",
- type=str,
- default=None,
- help="JSON string of additional parameters to be merged into the task payload for headless.py. E.g., '{\"guidance_scale\": 7.5, \"num_inference_steps\": 20}'"
- )
- common_parser.add_argument(
- "--after_first_post_generation_saturation",
- type=float,
- default=None,
- help="Saturation level (eq=saturation=) applied via FFmpeg to every generated video segment AFTER the first one. 1.0 means no change. Values <1.0 desaturate, >1.0 increase saturation. If omitted, no saturation change is applied."
- )
-
- parser = argparse.ArgumentParser(description="Generates steerable motion videos or performs other image tasks using headless.py.")
- subparsers = parser.add_subparsers(title="tasks", dest="task", required=True, help="Specify the task to perform.")
-
- parser_travel = subparsers.add_parser("travel_between_images", help="Generate video interpolating between multiple anchor images.", parents=[common_parser])
- parser_travel.add_argument("--input_images", nargs='+', required=True, help="List of input anchor image paths.")
- parser_travel.add_argument("--base_prompts", nargs='+', required=True, help="List of base prompts for each segment.")
- parser_travel.add_argument("--negative_prompts", nargs='+', default=[""], help="List of negative prompts for each segment, or a single one for all. Defaults to an empty string for no negative prompt.")
- parser_travel.add_argument("--frame_overlap", nargs='+', type=int, default=[16], help="Number of frames to overlap between segments. Can be one value for all, or one per segment (number of images - 1, or number of images if --continue_from_video is used).")
- parser_travel.add_argument("--segment_frames", nargs='+', type=int, default=[DEFAULT_SEGMENT_FRAMES], help="Total frames for each headless.py segment. Can be one value for all, or one per segment (number of images - 1, or number of images if --continue_from_video is used).")
- parser_travel.add_argument("--upscale_factor", type=float, default=0.0, help="Factor to upscale final video (e.g., 1.5, 2.0). Use 0.0 or 1.0 to disable. Uses model specified by --upscale_model_name.")
- parser_travel.add_argument("--continue_from_video", type=str, default=None, help="Path or URL to a video to continue from. If provided, this video acts as the first segment.")
-
- parser_pose = subparsers.add_parser("different_pose", help="Generate an image with a different pose based on an input image and prompt.", parents=[common_parser])
- parser_pose.add_argument("--input_image", type=str, required=True, help="Path to the single input image.")
- parser_pose.add_argument("--prompt", type=str, required=True, help="Prompt to guide the generation.")
-
- args = parser.parse_args()
-
- # Set the global DEBUG_MODE based on parsed arguments
- # This ensures dprint in this file and in common_utils (via its DEBUG_MODE) honors the CLI flag
- global DEBUG_MODE
- from sm_functions import common_utils as sm_common_utils
- DEBUG_MODE = args.debug
- sm_common_utils.DEBUG_MODE = args.debug # Explicitly set it in the imported module as well
-
- executed_command_str = None
- if DEBUG_MODE:
- script_name = Path(sys.argv[0]).name
- # Properly quote each argument, especially those with spaces
- quoted_args = [shlex.quote(arg) for arg in sys.argv[1:]]
- remaining_args = " ".join(quoted_args)
- executed_command_str = f"python {script_name} {remaining_args}"
- dprint(f"Executed command: {executed_command_str}")
-
- try:
- # parse_resolution is now imported from sm_functions (which gets it from common_utils)
- parsed_resolution_val = parse_resolution(args.resolution)
- except ValueError as e:
- parser.error(str(e))
-
- # --- Argument Validation specific to travel_between_images ---
- if args.task == "travel_between_images":
- if args.continue_from_video:
- if not args.input_images:
- parser.error("For 'travel_between_images' with --continue_from_video, at least one input image is required.")
- num_segments_expected = len(args.input_images)
- else:
- if len(args.input_images) < 2:
- # This check is somewhat redundant with num_segments_expected > 0 later, but good for clarity.
- parser.error("For 'travel_between_images' (without --continue_from_video), at least two input images are required.")
- num_segments_expected = len(args.input_images) - 1
-
- if num_segments_expected <= 0 and args.input_images: # Allow num_segments_expected to be 0 if no input_images for continue_from_video (already errored)
- pass # Error already handled or will be if num_segments_expected is critical for a check below and is zero.
- elif num_segments_expected <=0 :
- parser.error("No segments to generate based on input images and --continue_from_video flag.")
-
- # Validate --segment_frames
- if len(args.segment_frames) > 1 and len(args.segment_frames) != num_segments_expected:
- parser.error(
- f"Number of --segment_frames values ({len(args.segment_frames)}) must match the number of segments "
- f"({num_segments_expected}), or be a single value."
- )
- for i, val in enumerate(args.segment_frames):
- if val <= 0:
- parser.error(f"--segment_frames value at index {i} ('{val}') must be positive.")
-
- # Validate --frame_overlap
- if len(args.frame_overlap) > 1 and len(args.frame_overlap) != num_segments_expected:
- parser.error(
- f"Number of --frame_overlap values ({len(args.frame_overlap)}) must match the number of segments "
- f"({num_segments_expected}), or be a single value."
- )
- for i, val in enumerate(args.frame_overlap):
- if val < 0:
- parser.error(f"--frame_overlap value at index {i} ('{val}') cannot be negative.")
-
- # Validate --base_prompts
- if len(args.base_prompts) > 1 and len(args.base_prompts) != num_segments_expected:
- parser.error(
- f"Number of --base_prompts values ({len(args.base_prompts)}) must match the number of segments "
- f"({num_segments_expected}), or be a single value."
- )
-
- # Validate --negative_prompts
- if len(args.negative_prompts) > 1 and len(args.negative_prompts) != num_segments_expected:
- parser.error(
- f"Number of --negative_prompts values ({len(args.negative_prompts)}) must match the number of segments "
- f"({num_segments_expected}), or be a single value."
- )
-
- # Cross-validation: segment_frames vs frame_overlap for each segment
- if num_segments_expected > 0: # Only run if there are segments to validate
- for i in range(num_segments_expected):
- # Determine current segment's frame and overlap values
- current_segment_frames_val = args.segment_frames[0] if len(args.segment_frames) == 1 else args.segment_frames[i]
- current_frame_overlap_val = args.frame_overlap[0] if len(args.frame_overlap) == 1 else args.frame_overlap[i]
-
- if current_frame_overlap_val > 0 and current_segment_frames_val <= current_frame_overlap_val:
- start_image_name_for_error = ""
- if args.continue_from_video:
- if i == 0:
- start_image_name_for_error = "the continued video"
- else:
- start_image_name_for_error = f"image {args.input_images[i-1]}"
- else:
- start_image_name_for_error = f"image {args.input_images[i]}"
-
- end_image_name_for_error = f"image {args.input_images[i]}" if args.continue_from_video else f"image {args.input_images[i+1]}"
-
- parser.error(
- f"For segment {i+1} (transitioning from {start_image_name_for_error} to {end_image_name_for_error}), --segment_frames ({current_segment_frames_val}) "
- f"must be greater than --frame_overlap ({current_frame_overlap_val}) when overlap is used."
- )
- # --- End Argument Validation ---
-
- main_output_dir = Path(args.output_dir)
- main_output_dir.mkdir(parents=True, exist_ok=True)
- dprint(f"Main output directory: {main_output_dir.resolve()}")
-
- # Determine the database path: .env variable, then CLI arg, then default.
- # args.db_path already holds the CLI value or its default ("tasks.db").
- db_file_path_str = env_sqlite_db_path if env_sqlite_db_path else args.db_path
- dprint(f"Using database path: {Path(db_file_path_str).resolve()}")
-
- try:
- _ensure_db_initialized(db_file_path_str, DEFAULT_DB_TABLE_NAME)
- except Exception as e_db_init:
- print(f"Fatal: Could not initialize database: {e_db_init}")
- return 1
-
- # Propagate the debug flag to task modules so their local DEBUG_MODE copies stay in sync
- import sm_functions.travel_between_images as _travel_mod
- import sm_functions.different_pose as _diff_pose_mod
- _travel_mod.DEBUG_MODE = args.debug
- _diff_pose_mod.DEBUG_MODE = args.debug
-
- exit_code = 0
- if args.task == "travel_between_images":
- exit_code = run_travel_between_images_task(args, args, parsed_resolution_val, main_output_dir, db_file_path_str, executed_command_str)
- elif args.task == "different_pose":
- exit_code = run_different_pose_task(args, args, parsed_resolution_val, main_output_dir, db_file_path_str, executed_command_str)
- else:
- parser.error(f"Unknown task: {args.task}")
-
- print(f"\nSteerable motion script finished for task '{args.task}'. Exit code: {exit_code}")
- return exit_code
-
-if __name__ == "__main__":
- sys.exit(main())
\ No newline at end of file
diff --git a/supabase/functions/claim-next-task/index.ts b/supabase/functions/claim-next-task/index.ts
new file mode 100644
index 000000000..31e35fda4
--- /dev/null
+++ b/supabase/functions/claim-next-task/index.ts
@@ -0,0 +1,396 @@
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge function: claim-next-task
+ *
+ * Claims the next queued task atomically.
+ * - Service-role key: claims any task across all users
+ * - User token: claims only tasks for that specific user
+ *
+ * POST /functions/v1/claim-next-task
+ * Headers: Authorization: Bearer <token>
+ * Body: {} or { worker_id?: string } (worker_id is optional; only service-role claims use it)
+ *
+ * Returns:
+ * - 200 OK with task data
+ * - 204 No Content if no tasks available
+ * - 401 Unauthorized if no valid token
+ * - 403 Forbidden if token invalid or user not found
+ * - 500 Internal Server Error
+ */
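+
+// Example invocation (a sketch; the env-var names and the worker_id value are
+// illustrative, and service-role callers may omit worker_id entirely):
+//
+//   curl -X POST "$SUPABASE_URL/functions/v1/claim-next-task" \
+//     -H "Authorization: Bearer $SUPABASE_SERVICE_ROLE_KEY" \
+//     -H "Content-Type: application/json" \
+//     -d '{"worker_id": "gpu-worker-01"}'
+//
+// A 200 response carries { task_id, params, task_type, project_id }; 204 means no claimable task.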
+serve(async (req) => {
+ // Only accept POST requests
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ // Extract authorization header
+ const authHeader = req.headers.get("Authorization");
+ if (!authHeader?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+
+ const token = authHeader.slice(7); // Remove "Bearer " prefix
+ const serviceKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
+ const supabaseUrl = Deno.env.get("SUPABASE_URL");
+
+ if (!serviceKey || !supabaseUrl) {
+ console.error("Missing required environment variables");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Parse request body to get worker_id if provided
+ let requestBody: any = {};
+ try {
+ const bodyText = await req.text();
+ if (bodyText) {
+ requestBody = JSON.parse(bodyText);
+ }
+ } catch (e) {
+ console.log("No valid JSON body provided, using default worker_id");
+ }
+
+ // Create admin client for database operations
+ const supabaseAdmin = createClient(supabaseUrl, serviceKey);
+
+ let callerId: string | null = null;
+ let isServiceRole = false;
+
+ // 1) Check if token matches service-role key directly
+ if (token === serviceKey) {
+ isServiceRole = true;
+ console.log("Direct service-role key match");
+ }
+
+ // 2) If not service key, try to decode as JWT and check role
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // It's a JWT - decode and check role
+ // JWT payloads are base64url; convert to base64 and pad before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+
+ // Check for service role in various claim locations
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ console.log("JWT has service-role/admin role");
+ }
+ // Don't extract user ID from JWT - always look it up in user_api_token table
+ }
+ } catch (e) {
+ // Not a valid JWT - will be treated as PAT
+ console.log("Token is not a valid JWT, treating as PAT");
+ }
+ }
+
+ // 3) USER TOKEN PATH - ALWAYS resolve callerId via user_api_token table
+ if (!isServiceRole) {
+ console.log("Looking up token in user_api_token table...");
+
+ try {
+ // Query user_api_tokens table to find user
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ console.error("Token lookup failed:", error);
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+
+ callerId = data.user_id;
+ console.log(`Token resolved to user ID: ${callerId}`);
+
+ // Debug: Check user's projects and tasks
+ const { data: userProjects } = await supabaseAdmin
+ .from("projects")
+ .select("id, name")
+ .eq("user_id", callerId);
+
+ console.log(`DEBUG: User ${callerId} owns ${userProjects?.length || 0} projects`);
+
+ if (userProjects && userProjects.length > 0) {
+ const projectIds = userProjects.map(p => p.id);
+ const { data: userTasks } = await supabaseAdmin
+ .from("tasks")
+ .select("id, status, project_id, task_type, created_at")
+ .in("project_id", projectIds);
+
+ console.log(`DEBUG: Found ${userTasks?.length || 0} tasks across user's projects`);
+ if (userTasks && userTasks.length > 0) {
+ const queuedTasks = userTasks.filter(t => t.status === "Queued");
+ console.log(`DEBUG: ${queuedTasks.length} tasks are in 'Queued' status`);
+ console.log("DEBUG: Sample tasks:", JSON.stringify(userTasks.slice(0, 3), null, 2));
+
+ // Show unique status values to debug enum
+ const uniqueStatuses = [...new Set(userTasks.map(t => t.status))];
+ console.log(`DEBUG: Unique status values: ${JSON.stringify(uniqueStatuses)}`);
+ }
+ } else {
+ console.log(`DEBUG: User ${callerId} has no projects - cannot claim any tasks`);
+ }
+ } catch (e) {
+ console.error("Error querying user_api_token:", e);
+ return new Response("Token validation failed", { status: 403 });
+ }
+ }
+
+ // Handle worker_id based on token type
+ let workerId: string | null = null;
+ if (isServiceRole) {
+ // Service role: use provided worker_id or generate one
+ workerId = requestBody.worker_id || `edge_${crypto.randomUUID()}`;
+ console.log(`Service role using worker_id: ${workerId}`);
+ } else {
+ // User/PAT: no worker_id needed (individual users don't have worker IDs)
+ console.log(`User token: not using worker_id`);
+ }
+
+ try {
+ // Execute database operation based on token type
+ let rpcResponse;
+
+ if (isServiceRole) {
+ // Service role: claim any available task from any project atomically
+ console.log("Service role: Executing atomic find-and-claim for all tasks");
+
+ const serviceUpdatePayload = {
+ status: "In Progress" as const,
+ worker_id: workerId, // Service role gets worker_id for tracking
+ updated_at: new Date().toISOString()
+ };
+
+ // Get all queued tasks and manually check dependencies
+ const { data: queuedTasks, error: findError } = await supabaseAdmin
+ .from("tasks")
+ .select("id, params, task_type, project_id, created_at, dependant_on")
+ .eq("status", "Queued")
+ .order("created_at", { ascending: true });
+
+ if (findError) {
+ throw findError;
+ }
+
+ // Manual dependency checking for service role
+ const readyTasks: any[] = [];
+ for (const task of (queuedTasks || [])) {
+ if (!task.dependant_on) {
+ // No dependency - task is ready
+ readyTasks.push(task);
+ } else {
+ // Check if dependency is complete
+ const { data: depData } = await supabaseAdmin
+ .from("tasks")
+ .select("status")
+ .eq("id", task.dependant_on)
+ .single();
+
+ if (depData?.status === "Complete") {
+ readyTasks.push(task);
+ }
+ }
+ }
+
+ console.log(`Service role dependency check: ${queuedTasks?.length || 0} queued, ${readyTasks.length} ready`);
+
+ let updateData: any = null;
+ let updateError: any = null;
+
+ if (readyTasks.length > 0) {
+ const taskToTake = readyTasks[0];
+
+ // Atomically claim the first eligible task
+ const result = await supabaseAdmin
+ .from("tasks")
+ .update(serviceUpdatePayload)
+ .eq("id", taskToTake.id)
+ .eq("status", "Queued") // Double-check it's still queued
+ .select()
+ .single();
+
+ updateData = result.data;
+ updateError = result.error;
+ } else {
+ // No eligible tasks found - set error to indicate no rows
+ updateError = { code: "PGRST116", message: "No eligible tasks found" };
+ }
+
+ console.log(`Service role atomic claim result - error: ${updateError?.message || updateError?.code || 'none'}, data: ${updateData ? 'claimed task ' + updateData.id : 'no data'}`);
+
+ if (updateError && updateError.code !== "PGRST116") { // PGRST116 = no rows
+ console.error("Service role atomic claim failed:", updateError);
+ throw updateError;
+ }
+
+ if (updateData) {
+ console.log(`Service role successfully claimed task ${updateData.id} atomically`);
+ rpcResponse = {
+ data: [{
+ task_id_out: updateData.id,
+ params_out: updateData.params,
+ task_type_out: updateData.task_type,
+ project_id_out: updateData.project_id
+ }],
+ error: null
+ };
+ } else {
+ console.log("Service role: No queued tasks available for atomic claiming");
+ rpcResponse = { data: [], error: null };
+ }
+ } else {
+ // User token: use the user-specific claim function
+ console.log(`Claiming task for user ${callerId}...`);
+
+ try {
+ // Try the user-specific function first
+ // First get user's project IDs, then query tasks
+ const { data: userProjects } = await supabaseAdmin
+ .from("projects")
+ .select("id")
+ .eq("user_id", callerId);
+
+ if (!userProjects || userProjects.length === 0) {
+ console.log("User has no projects");
+ rpcResponse = { data: [], error: null };
+ } else {
+ const projectIds = userProjects.map(p => p.id);
+ console.log(`DEBUG: Claiming from ${projectIds.length} project IDs: [${projectIds.slice(0, 3).join(', ')}...]`);
+
+ if (projectIds.length === 0) {
+ console.log("No project IDs to search - user has projects but they have no IDs?");
+ rpcResponse = { data: [], error: null };
+ } else {
+ // Get queued tasks for user projects and manually check dependencies
+ console.log(`DEBUG: Finding eligible tasks with dependency checking for ${projectIds.length} projects`);
+
+ const { data: userQueuedTasks, error: userFindError } = await supabaseAdmin
+ .from("tasks")
+ .select("id, params, task_type, project_id, created_at, dependant_on")
+ .eq("status", "Queued")
+ .in("project_id", projectIds)
+ .order("created_at", { ascending: true });
+
+ if (userFindError) {
+ throw userFindError;
+ }
+
+ // Manual dependency checking for user tasks
+ const userReadyTasks: any[] = [];
+ for (const task of (userQueuedTasks || [])) {
+ if (!task.dependant_on) {
+ // No dependency - task is ready
+ userReadyTasks.push(task);
+ } else {
+ // Check if dependency is complete
+ const { data: depData } = await supabaseAdmin
+ .from("tasks")
+ .select("status")
+ .eq("id", task.dependant_on)
+ .single();
+
+ if (depData?.status === "Complete") {
+ userReadyTasks.push(task);
+ }
+ }
+ }
+
+ console.log(`DEBUG: User dependency check: ${userQueuedTasks?.length || 0} queued, ${userReadyTasks.length} ready`);
+
+ const updatePayload: any = {
+ status: "In Progress",
+ updated_at: new Date().toISOString()
+ // Note: No worker_id for user claims - individual users don't have worker IDs
+ };
+
+ let updateData: any = null;
+ let updateError: any = null;
+
+ if (userReadyTasks.length > 0) {
+ const taskToTake = userReadyTasks[0];
+
+ // Atomically claim the first eligible task
+ const result = await supabaseAdmin
+ .from("tasks")
+ .update(updatePayload)
+ .eq("id", taskToTake.id)
+ .eq("status", "Queued") // Double-check it's still queued
+ .select()
+ .single();
+
+ updateData = result.data;
+ updateError = result.error;
+ } else {
+ // No eligible tasks found
+ updateError = { code: "PGRST116", message: "No eligible tasks found for user" };
+ }
+
+ console.log(`DEBUG: User atomic claim result - error: ${updateError?.message || updateError?.code || 'none'}, data: ${updateData ? 'claimed task ' + updateData.id : 'no data'}`);
+
+ if (updateError && updateError.code !== "PGRST116") { // PGRST116 = no rows
+ console.error("User atomic claim failed:", updateError);
+ throw updateError;
+ }
+
+ if (updateData) {
+ // Successfully claimed atomically
+ console.log(`Successfully claimed task ${updateData.id} atomically for user`);
+ rpcResponse = {
+ data: [{
+ task_id_out: updateData.id,
+ params_out: updateData.params,
+ task_type_out: updateData.task_type,
+ project_id_out: updateData.project_id
+ }],
+ error: null
+ };
+ } else {
+ // No tasks available or all were claimed by others
+ console.log("No queued tasks available for user atomic claiming");
+ rpcResponse = { data: [], error: null };
+ }
+ }
+ }
+ } catch (e) {
+ console.error("Error claiming user task:", e);
+ rpcResponse = { data: [], error: null };
+ }
+ }
+
+ // Check RPC response
+ if (rpcResponse.error) {
+ console.error("RPC error:", rpcResponse.error);
+ return new Response(`Database error: ${rpcResponse.error.message}`, { status: 500 });
+ }
+
+ // Check if we got a task
+ if (!rpcResponse.data || rpcResponse.data.length === 0) {
+ console.log("No queued tasks available");
+ return new Response(null, { status: 204 });
+ }
+
+ const task = rpcResponse.data[0];
+ console.log(`Successfully claimed task ${task.task_id_out}`);
+
+ // Return the task data
+ return new Response(JSON.stringify({
+ task_id: task.task_id_out,
+ params: task.params_out,
+ task_type: task.task_type_out,
+ project_id: task.project_id_out
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+
+ } catch (error) {
+ console.error("Unexpected error:", error);
+ return new Response(`Internal server error: ${error.message}`, { status: 500 });
+ }
+});
\ No newline at end of file
diff --git a/supabase/functions/complete-task/index.ts b/supabase/functions/complete-task/index.ts
new file mode 100644
index 000000000..8f4ebd2b0
--- /dev/null
+++ b/supabase/functions/complete-task/index.ts
@@ -0,0 +1,336 @@
+// deno-lint-ignore-file
+// @ts-ignore
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+declare const Deno: any;
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge function: complete-task
+ *
+ * Completes a task by uploading file data and updating task status.
+ * - Service-role key: can complete any task
+ * - User token: can only complete tasks they own
+ *
+ * POST /functions/v1/complete-task
+ * Headers: Authorization: Bearer <token>
+ * Body: { task_id, file_data: "base64...", filename: "image.png" }
+ *
+ * Returns:
+ * - 200 OK with success data
+ * - 401 Unauthorized if no valid token
+ * - 403 Forbidden if token invalid or user not authorized
+ * - 500 Internal Server Error
+ */
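+
+// Example invocation (a sketch; the env-var names are illustrative, and the
+// base64 piping shown assumes GNU coreutils' `base64 -w0` on the worker host):
+//
+//   curl -X POST "$SUPABASE_URL/functions/v1/complete-task" \
+//     -H "Authorization: Bearer $WORKER_TOKEN" \
+//     -H "Content-Type: application/json" \
+//     -d "{\"task_id\": \"$TASK_ID\", \"filename\": \"result.mp4\", \"file_data\": \"$(base64 -w0 result.mp4)\"}"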
+serve(async (req) => {
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ let body;
+ try {
+ body = await req.json();
+ } catch (e) {
+ return new Response("Invalid JSON body", { status: 400 });
+ }
+
+ const { task_id, file_data, filename } = body;
+
+ console.log(`[COMPLETE-TASK-DEBUG] Received request with task_id type: ${typeof task_id}, value: ${JSON.stringify(task_id)}`);
+ console.log(`[COMPLETE-TASK-DEBUG] Body keys: ${Object.keys(body)}`);
+
+ if (!task_id || !file_data || !filename) {
+ return new Response("task_id, file_data (base64), and filename required", { status: 400 });
+ }
+
+ // Convert task_id to string early to avoid UUID casting issues
+ const taskIdString = String(task_id);
+ console.log(`[COMPLETE-TASK-DEBUG] Converted task_id to string: ${taskIdString}`);
+
+ // Extract authorization header
+ const authHeader = req.headers.get("Authorization");
+ if (!authHeader?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+
+ const token = authHeader.slice(7); // Remove "Bearer " prefix
+ const serviceKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
+ const supabaseUrl = Deno.env.get("SUPABASE_URL");
+
+ if (!serviceKey || !supabaseUrl) {
+ console.error("Missing required environment variables");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Create admin client for database operations
+ const supabaseAdmin = createClient(supabaseUrl, serviceKey);
+
+ let callerId: string | null = null;
+ let isServiceRole = false;
+
+ // 1) Check if token matches service-role key directly
+ if (token === serviceKey) {
+ isServiceRole = true;
+ console.log("Direct service-role key match");
+ }
+
+ // 2) If not service key, try to decode as JWT and check role
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // It's a JWT - decode and check role
+ // JWT payloads are base64url; convert to base64 and pad before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+
+ // Check for service role in various claim locations
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ console.log("JWT has service-role/admin role");
+ }
+ // Don't extract user ID from JWT - always look it up in user_api_token table
+ }
+ } catch (e) {
+ // Not a valid JWT - will be treated as PAT
+ console.log("Token is not a valid JWT, treating as PAT");
+ }
+ }
+
+ // 3) USER TOKEN PATH - ALWAYS resolve callerId via user_api_token table
+ if (!isServiceRole) {
+ console.log("Looking up token in user_api_token table...");
+
+ try {
+ // Query user_api_tokens table to find user
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ console.error("Token lookup failed:", error);
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+
+ callerId = data.user_id;
+ console.log(`Token resolved to user ID: ${callerId}`);
+ } catch (e) {
+ console.error("Error querying user_api_token:", e);
+ return new Response("Token validation failed", { status: 403 });
+ }
+ }
+
+ try {
+ // 4) If user token, verify task ownership
+ if (!isServiceRole && callerId) {
+ console.log(`[COMPLETE-TASK-DEBUG] Verifying task ${taskIdString} belongs to user ${callerId}...`);
+ console.log(`[COMPLETE-TASK-DEBUG] taskIdString type: ${typeof taskIdString}, value: ${taskIdString}`);
+
+ const { data: taskData, error: taskError } = await supabaseAdmin
+ .from("tasks")
+ .select("project_id")
+ .eq("id", taskIdString)
+ .single();
+
+ if (taskError) {
+ console.error("Task lookup error:", taskError);
+ return new Response("Task not found", { status: 404 });
+ }
+
+ // Check if user owns the project that this task belongs to
+ const { data: projectData, error: projectError } = await supabaseAdmin
+ .from("projects")
+ .select("user_id")
+ .eq("id", taskData.project_id)
+ .single();
+
+ if (projectError) {
+ console.error("Project lookup error:", projectError);
+ return new Response("Project not found", { status: 404 });
+ }
+
+ if (projectData.user_id !== callerId) {
+ console.error(`Task ${taskIdString} belongs to project ${taskData.project_id} owned by ${projectData.user_id}, not user ${callerId}`);
+ return new Response("Forbidden: Task does not belong to user", { status: 403 });
+ }
+
+ console.log(`Task ${taskIdString} ownership verified: user ${callerId} owns project ${taskData.project_id}`);
+ }
+
+ // 5) Decode the base64 file data
+ const fileBuffer = Uint8Array.from(atob(file_data), c => c.charCodeAt(0));
+
+ // 6) Determine the storage path
+ let userId: string;
+ if (isServiceRole) {
+ // For service role, we need to determine the appropriate user folder
+ // Get the task to find which project (and user) it belongs to
+ console.log(`[COMPLETE-TASK-DEBUG] Service role - looking up task ${taskIdString} for storage path determination`);
+ console.log(`[COMPLETE-TASK-DEBUG] taskIdString type: ${typeof taskIdString}, value: ${taskIdString}`);
+
+ const { data: taskData, error: taskError } = await supabaseAdmin
+ .from("tasks")
+ .select("project_id")
+ .eq("id", taskIdString)
+ .single();
+
+ if (taskError) {
+ console.error("Task lookup error for storage path:", taskError);
+ return new Response("Task not found", { status: 404 });
+ }
+
+ // Get the project owner
+ const { data: projectData, error: projectError } = await supabaseAdmin
+ .from("projects")
+ .select("user_id")
+ .eq("id", taskData.project_id)
+ .single();
+
+ if (projectError) {
+ console.error("Project lookup error for storage path:", projectError);
+ // Fallback to system folder if we can't determine owner
+ userId = 'system';
+ } else {
+ userId = projectData.user_id;
+ }
+ console.log(`Service role storing file for task ${taskIdString} in user ${userId}'s folder`);
+ } else {
+ // For user tokens, use the authenticated user's ID
+ userId = callerId!;
+ }
+
+ const objectPath = `${userId}/${filename}`;
+
+ // 7) Upload to Supabase Storage
+ const { data: uploadData, error: uploadError } = await supabaseAdmin.storage
+ .from('image_uploads')
+ .upload(objectPath, fileBuffer, {
+ contentType: getContentType(filename),
+ upsert: true
+ });
+
+ if (uploadError) {
+ console.error("Storage upload error:", uploadError);
+ return new Response(`Storage upload failed: ${uploadError.message}`, { status: 500 });
+ }
+
+ // 8) Get the public URL
+ const { data: urlData } = supabaseAdmin.storage
+ .from('image_uploads')
+ .getPublicUrl(objectPath);
+
+ const publicUrl = urlData.publicUrl;
+
+ // 9) Update the database with the public URL
+ console.log(`[COMPLETE-TASK-DEBUG] Updating task ${taskIdString} to Complete status`);
+
+ const { error: dbError } = await supabaseAdmin
+ .from("tasks")
+ .update({
+ status: "Complete",
+ output_location: publicUrl,
+ generation_processed_at: new Date().toISOString()
+ })
+ .eq("id", taskIdString)
+ .eq("status", "In Progress");
+
+ if (dbError) {
+ console.error("[COMPLETE-TASK-DEBUG] Database update error:", dbError);
+ // If DB update fails, we should clean up the uploaded file
+ await supabaseAdmin.storage.from('image_uploads').remove([objectPath]);
+ return new Response(`Database update failed: ${dbError.message}`, { status: 500 });
+ }
+
+ console.log(`[COMPLETE-TASK-DEBUG] Database update successful for task ${taskIdString}`);
+
+ // 10) Check if this task completes an orchestrator workflow
+ try {
+ // Get the task details to check if it's a final task in an orchestrator workflow
+ console.log(`[COMPLETE-TASK-DEBUG] Checking orchestrator workflow for task ${taskIdString}`);
+ console.log(`[COMPLETE-TASK-DEBUG] taskIdString type: ${typeof taskIdString}, value: ${taskIdString}`);
+
+ const { data: taskData, error: taskError } = await supabaseAdmin
+ .from("tasks")
+ .select("task_type, params")
+ .eq("id", taskIdString)
+ .single();
+
+ if (!taskError && taskData) {
+ const { task_type, params } = taskData;
+
+ // Check if this is a final task that should complete an orchestrator
+ const isFinalTask = (
+ task_type === "travel_stitch" ||
+ task_type === "dp_final_gen"
+ );
+
+ if (isFinalTask && params?.orchestrator_task_id_ref) {
+ console.log(`[COMPLETE-TASK-DEBUG] Task ${taskIdString} is a final ${task_type} task. Marking orchestrator ${params.orchestrator_task_id_ref} as complete.`);
+
+ // Update the orchestrator task to Complete status with the same output location
+ const orchestratorIdString = String(params.orchestrator_task_id_ref);
+ console.log(`[COMPLETE-TASK-DEBUG] Orchestrator ID string: ${orchestratorIdString}, type: ${typeof orchestratorIdString}`);
+
+ const { error: orchError } = await supabaseAdmin
+ .from("tasks")
+ .update({
+ status: "Complete",
+ output_location: publicUrl,
+ generation_processed_at: new Date().toISOString()
+ })
+ .eq("id", orchestratorIdString)
+ .eq("status", "In Progress"); // Only update if still in progress
+
+ if (orchError) {
+ console.error(`[COMPLETE-TASK-DEBUG] Failed to update orchestrator ${params.orchestrator_task_id_ref}:`, orchError);
+ console.error(`[COMPLETE-TASK-DEBUG] Orchestrator error details:`, JSON.stringify(orchError, null, 2));
+ // Don't fail the whole request, just log the error
+ } else {
+ console.log(`[COMPLETE-TASK-DEBUG] Successfully marked orchestrator ${params.orchestrator_task_id_ref} as complete.`);
+ }
+ }
+ }
+ } catch (orchCheckError) {
+ // Don't fail the main request if orchestrator check fails
+ console.error("Error checking for orchestrator completion:", orchCheckError);
+ }
+
+ console.log(`[COMPLETE-TASK-DEBUG] Successfully completed task ${taskIdString} by ${isServiceRole ? 'service-role' : `user ${callerId}`}`);
+
+ const responseData = {
+ success: true,
+ public_url: publicUrl,
+ message: "Task completed and file uploaded successfully"
+ };
+ console.log(`[COMPLETE-TASK-DEBUG] Returning success response: ${JSON.stringify(responseData)}`);
+
+ return new Response(JSON.stringify(responseData), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+
+ } catch (error) {
+ console.error("[COMPLETE-TASK-DEBUG] Edge function error:", error);
+ console.error("[COMPLETE-TASK-DEBUG] Error stack:", error.stack);
+ console.error("[COMPLETE-TASK-DEBUG] Error details:", JSON.stringify(error, null, 2));
+ return new Response(`Internal error: ${error.message}`, { status: 500 });
+ }
+});
+
+function getContentType(filename: string): string {
+ const ext = filename.toLowerCase().split('.').pop();
+ switch (ext) {
+ case 'png': return 'image/png';
+ case 'jpg':
+ case 'jpeg': return 'image/jpeg';
+ case 'gif': return 'image/gif';
+ case 'webp': return 'image/webp';
+ case 'mp4': return 'video/mp4';
+ case 'webm': return 'video/webm';
+ case 'mov': return 'video/quicktime';
+ default: return 'application/octet-stream';
+ }
+}
\ No newline at end of file
diff --git a/supabase/functions/create-task/index.ts b/supabase/functions/create-task/index.ts
new file mode 100644
index 000000000..eb33b4f74
--- /dev/null
+++ b/supabase/functions/create-task/index.ts
@@ -0,0 +1,179 @@
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge function: create-task
+ *
+ * Creates a new task in the queue.
+ * - Service-role key: can create tasks for any project_id
+ * - User token: can only create tasks for their own project_id
+ *
+ * POST /functions/v1/create-task
+ * Headers: Authorization: Bearer <token>
+ * Body: { task_id, params, task_type, project_id?, dependant_on? }
+ *
+ * Returns:
+ * - 200 OK with success message
+ * - 401 Unauthorized if no valid token
+ * - 403 Forbidden if token invalid or user not authorized
+ * - 500 Internal Server Error
+ */
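+
+// Example invocation (a sketch; the UUIDs, task_type, and params shown are illustrative only):
+//
+//   curl -X POST "$SUPABASE_URL/functions/v1/create-task" \
+//     -H "Authorization: Bearer $USER_PAT" \
+//     -H "Content-Type: application/json" \
+//     -d '{"task_id": "123e4567-e89b-12d3-a456-426614174000",
+//          "task_type": "travel_segment",
+//          "project_id": "<your-project-uuid>",
+//          "params": {"prompt": "a quiet forest", "resolution": "960x544"}}'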
+serve(async (req) => {
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ // ─── 1. Parse body ──────────────────────────────────────────────
+ let body: any;
+ try {
+ body = await req.json();
+ } catch {
+ return new Response("Invalid JSON body", { status: 400 });
+ }
+ const { task_id, params, task_type, project_id, dependant_on } = body;
+ if (!task_id || !params || !task_type) {
+ return new Response("task_id, params, task_type required", { status: 400 });
+ }
+
+ // ─── 2. Extract authorization header ────────────────────────────
+ const authHeader = req.headers.get("Authorization");
+ if (!authHeader?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+
+ const token = authHeader.slice(7); // Remove "Bearer " prefix
+ const serviceKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
+ const supabaseUrl = Deno.env.get("SUPABASE_URL");
+
+ if (!serviceKey || !supabaseUrl) {
+ console.error("Missing required environment variables");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Create admin client for database operations
+ const supabaseAdmin = createClient(supabaseUrl, serviceKey);
+
+ let callerId: string | null = null;
+ let isServiceRole = false;
+
+ // ─── 3. Check if token matches service-role key directly ────────
+ if (token === serviceKey) {
+ isServiceRole = true;
+ console.log("Direct service-role key match");
+ }
+
+ // ─── 4. If not service key, try to decode as JWT and check role ──
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // It's a JWT - decode and check role
+ // JWT payloads are base64url; convert to base64 and pad before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+
+ // Check for service role in various claim locations
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ console.log("JWT has service-role/admin role");
+ }
+ // Don't extract user ID from JWT - always look it up in user_api_token table
+ }
+ } catch (e) {
+ // Not a valid JWT - will be treated as PAT
+ console.log("Token is not a valid JWT, treating as PAT");
+ }
+ }
+
+ // ─── 5. USER TOKEN PATH - resolve callerId via user_api_token table ──
+ if (!isServiceRole) {
+ console.log("Looking up token in user_api_token table...");
+
+ try {
+ // Query user_api_tokens table to find user
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ console.error("Token lookup failed:", error);
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+
+ callerId = data.user_id;
+ console.log(`Token resolved to user ID: ${callerId}`);
+ } catch (e) {
+ console.error("Error querying user_api_token:", e);
+ return new Response("Token validation failed", { status: 403 });
+ }
+ }
+
+ // ─── 6. Determine final project_id and validate permissions ─────
+ let finalProjectId: string;
+
+ if (isServiceRole) {
+ // Service role can create tasks for any project_id
+ if (!project_id) {
+ return new Response("project_id required for service role", { status: 400 });
+ }
+ finalProjectId = project_id;
+ console.log(`Service role creating task for project: ${finalProjectId}`);
+ } else {
+ // User token validation
+ if (!callerId) {
+ return new Response("Could not determine user ID", { status: 401 });
+ }
+
+ if (!project_id) {
+ return new Response("project_id required", { status: 400 });
+ }
+
+ // Verify user owns the specified project
+ const { data: projectData, error: projectError } = await supabaseAdmin
+ .from("projects")
+ .select("user_id")
+ .eq("id", project_id)
+ .single();
+
+ if (projectError) {
+ console.error("Project lookup error:", projectError);
+ return new Response("Project not found", { status: 404 });
+ }
+
+ if (projectData.user_id !== callerId) {
+ console.error(`User ${callerId} attempted to create task in project ${project_id} owned by ${projectData.user_id}`);
+ return new Response("Forbidden: You don't own this project", { status: 403 });
+ }
+
+ finalProjectId = project_id;
+ console.log(`User ${callerId} creating task in their owned project ${finalProjectId}`);
+ }
+
+ // ─── 7. Insert row using admin client ───────────────────────────
+ try {
+ const { error } = await supabaseAdmin.from("tasks").insert({
+ id: task_id,
+ params,
+ task_type,
+ project_id: finalProjectId,
+ dependant_on: dependant_on ?? null,
+ status: "Queued",
+ created_at: new Date().toISOString(),
+ });
+
+ if (error) {
+ console.error("create_task error:", error);
+ return new Response(error.message, { status: 500 });
+ }
+
+ console.log(`Successfully created task ${task_id} for project ${finalProjectId} by ${isServiceRole ? 'service-role' : `user ${callerId}`}`);
+ return new Response("Task queued", { status: 200 });
+
+ } catch (error) {
+ console.error("Unexpected error:", error);
+ return new Response(`Internal server error: ${error.message}`, { status: 500 });
+ }
+});
\ No newline at end of file
diff --git a/supabase/functions/create-task/supabase.toml b/supabase/functions/create-task/supabase.toml
new file mode 100644
index 000000000..1b657ac30
--- /dev/null
+++ b/supabase/functions/create-task/supabase.toml
@@ -0,0 +1,2 @@
+[functions."create-task"]
+verify_jwt = false
\ No newline at end of file
diff --git a/supabase/functions/get-completed-segments/index.ts b/supabase/functions/get-completed-segments/index.ts
new file mode 100644
index 000000000..d7e89b777
--- /dev/null
+++ b/supabase/functions/get-completed-segments/index.ts
@@ -0,0 +1,173 @@
+// deno-lint-ignore-file
+// @ts-ignore
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+declare const Deno: any;
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge Function: get-completed-segments
+ * Retrieves all completed travel_segment tasks for a given run_id.
+ *
+ * Auth rules:
+ * - Service-role key: full access.
+ * - JWT with service/admin role: full access.
+ * - Personal access token (PAT): must resolve via user_api_tokens and caller must own the project_id supplied.
+ *
+ * Request (POST):
+ * {
+ * "run_id": "string", // required
+ * "project_id": "uuid" // required for PAT / user JWT tokens
+ * }
+ *
+ * Returns 200 with: [{ segment_index, output_location }]
+ */
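+
+// Example invocation (a sketch; the run_id format shown is illustrative):
+//
+//   curl -X POST "$SUPABASE_URL/functions/v1/get-completed-segments" \
+//     -H "Authorization: Bearer $USER_PAT" \
+//     -H "Content-Type: application/json" \
+//     -d '{"run_id": "travel_run_001", "project_id": "<your-project-uuid>"}'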
+
+const corsHeaders = {
+ "Access-Control-Allow-Headers": "authorization, x-client-info, apikey, content-type",
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
+};
+
+serve(async (req) => {
+ if (req.method === "OPTIONS") {
+ return new Response("ok", { headers: corsHeaders });
+ }
+
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ try {
+ const body = await req.json();
+ const { run_id, project_id } = body;
+
+ if (!run_id) {
+ return new Response("run_id is required", { status: 400 });
+ }
+
+ // ─── Extract & validate Authorization header ──────────────────────────
+ const authHeaderFull = req.headers.get("Authorization");
+ if (!authHeaderFull?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+ const token = authHeaderFull.slice(7);
+
+ // ─── Environment vars ─────────────────────────────────────────────────
+ const SUPABASE_URL = Deno.env.get("SUPABASE_URL") ?? "";
+ const SERVICE_KEY = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY") ?? "";
+
+ if (!SUPABASE_URL || !SERVICE_KEY) {
+ console.error("SUPABASE_URL or SERVICE_KEY missing in env");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Admin client (always service role)
+ const supabaseAdmin = createClient(SUPABASE_URL, SERVICE_KEY);
+
+ let isServiceRole = false;
+ let callerId: string | null = null;
+
+ // 1) Direct key match
+ if (token === SERVICE_KEY) {
+ isServiceRole = true;
+ }
+
+ // 2) JWT role check
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // JWT payloads are base64url; convert to base64 and pad before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ }
+ }
+ } catch (_) {
+ /* ignore decode errors */
+ }
+ }
+
+ // 3) PAT lookup
+ if (!isServiceRole) {
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+ callerId = data.user_id;
+ }
+
+ // ─── Authorization for non-service callers ────────────────────────────
+ let effectiveProjectId: string | undefined = project_id;
+
+ if (!isServiceRole) {
+ if (!effectiveProjectId) {
+ return new Response("project_id required for user tokens", { status: 400 });
+ }
+
+ // Ensure caller owns the project
+ const { data: proj, error: projErr } = await supabaseAdmin
+ .from("projects")
+ .select("user_id")
+ .eq("id", effectiveProjectId)
+ .single();
+
+ if (projErr || !proj) {
+ return new Response("Project not found", { status: 404 });
+ }
+ if (proj.user_id !== callerId) {
+ return new Response("Forbidden: You don't own this project", { status: 403 });
+ }
+ }
+
+ // ─── Query completed segments ─────────────────────────────────────────
+ let query = supabaseAdmin
+ .from("tasks")
+ .select("params, output_location")
+ .eq("task_type", "travel_segment")
+ .eq("status", "Complete");
+
+ if (!isServiceRole) {
+ query = query.eq("project_id", effectiveProjectId as string);
+ }
+
+ const { data: rows, error: qErr } = await query;
+ if (qErr) {
+ console.error(qErr);
+ return new Response("Database query error", { status: 500 });
+ }
+
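+ // params may be stored as a JSON string or an object; normalise it, then keep
+ // only completed segments whose orchestrator_run_id matches the requested run.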
+ const results: { segment_index: number; output_location: string }[] = [];
+ for (const row of rows ?? []) {
+ const paramsObj = typeof row.params === "string" ? JSON.parse(row.params) : row.params;
+ if (
+ paramsObj.orchestrator_run_id === run_id &&
+ typeof paramsObj.segment_index === "number" &&
+ row.output_location
+ ) {
+ results.push({ segment_index: paramsObj.segment_index, output_location: row.output_location });
+ }
+ }
+
+ results.sort((a, b) => a.segment_index - b.segment_index);
+
+ return new Response(JSON.stringify(results), {
+ headers: { ...corsHeaders, "Content-Type": "application/json" },
+ status: 200,
+ });
+ } catch (e) {
+ console.error(e);
+ return new Response(JSON.stringify({ error: (e as Error).message }), {
+ status: 500,
+ headers: { ...corsHeaders, "Content-Type": "application/json" },
+ });
+ }
+});
\ No newline at end of file
diff --git a/supabase/functions/get-completed-segments/supabase.toml b/supabase/functions/get-completed-segments/supabase.toml
new file mode 100644
index 000000000..55c727f00
--- /dev/null
+++ b/supabase/functions/get-completed-segments/supabase.toml
@@ -0,0 +1,2 @@
+[functions.get-completed-segments]
+verify_jwt = false
\ No newline at end of file
diff --git a/supabase/functions/get-predecessor-output/index.ts b/supabase/functions/get-predecessor-output/index.ts
new file mode 100644
index 000000000..871f32093
--- /dev/null
+++ b/supabase/functions/get-predecessor-output/index.ts
@@ -0,0 +1,209 @@
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge function: get-predecessor-output
+ *
+ * Gets the output location of a task's dependency in a single call.
+ * Combines dependency lookup + output location retrieval.
+ *
+ * POST /functions/v1/get-predecessor-output
+ * Headers: Authorization: Bearer <token>
+ * Body: { task_id: "uuid" }
+ *
+ * Returns:
+ * - 200 OK with { predecessor_id, output_location } or null if no dependency
+ * - 400 Bad Request if task_id missing
+ * - 401 Unauthorized if no valid token
+ * - 403 Forbidden if token invalid or user not authorized
+ * - 404 Not Found if task not found
+ * - 500 Internal Server Error
+ */
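+
+// Example call (illustrative sketch only; URL, token and task_id are placeholders —
+// assumes the function is deployed at the project's standard Functions URL):
+//
+//   const res = await fetch(`${SUPABASE_URL}/functions/v1/get-predecessor-output`, {
+//     method: "POST",
+//     headers: { Authorization: `Bearer ${token}`, "Content-Type": "application/json" },
+//     body: JSON.stringify({ task_id: "some-task-uuid" }),
+//   });
+//   // 200 → { predecessor_id, output_location }; both are null when the task has no
+//   // dependency, and output_location is null (with a status field) when the
+//   // predecessor exists but is not yet Complete.
+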
+serve(async (req) => {
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ let body;
+ try {
+ body = await req.json();
+ } catch (e) {
+ return new Response("Invalid JSON body", { status: 400 });
+ }
+
+ const { task_id } = body;
+
+ if (!task_id) {
+ return new Response("task_id is required", { status: 400 });
+ }
+
+ // Extract authorization header
+ const authHeader = req.headers.get("Authorization");
+ if (!authHeader?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+
+ const token = authHeader.slice(7); // Remove "Bearer " prefix
+ const serviceKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
+ const supabaseUrl = Deno.env.get("SUPABASE_URL");
+
+ if (!serviceKey || !supabaseUrl) {
+ console.error("Missing required environment variables");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Create admin client for database operations
+ const supabaseAdmin = createClient(supabaseUrl, serviceKey);
+
+ let callerId: string | null = null;
+ let isServiceRole = false;
+
+ // 1) Check if token matches service-role key directly
+ if (token === serviceKey) {
+ isServiceRole = true;
+ console.log("Direct service-role key match");
+ }
+
+ // 2) If not service key, try to decode as JWT and check role
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // It's a JWT - decode and check role
+ // JWT payloads are base64url-encoded; map URL-safe chars to standard base64 before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+
+ // Check for service role in various claim locations
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ console.log("JWT has service-role/admin role");
+ }
+ }
+ } catch (e) {
+ // Not a valid JWT - will be treated as PAT
+ console.log("Token is not a valid JWT, treating as PAT");
+ }
+ }
+
+ // 3) USER TOKEN PATH - resolve callerId via the user_api_tokens table
+ if (!isServiceRole) {
+ console.log("Looking up token in user_api_tokens table...");
+
+ try {
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ console.error("Token lookup failed:", error);
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+
+ callerId = data.user_id;
+ console.log(`Token resolved to user ID: ${callerId}`);
+ } catch (e) {
+ console.error("Error querying user_api_token:", e);
+ return new Response("Token validation failed", { status: 403 });
+ }
+ }
+
+ try {
+ // Get the task info first
+ const { data: taskData, error: taskError } = await supabaseAdmin
+ .from("tasks")
+ .select("id, dependant_on, project_id")
+ .eq("id", task_id)
+ .single();
+
+ if (taskError) {
+ console.error("Task lookup error:", taskError);
+ return new Response("Task not found", { status: 404 });
+ }
+
+ // Check authorization if not service role
+ if (!isServiceRole && callerId) {
+ console.log(`Verifying task ${task_id} belongs to user ${callerId}...`);
+
+ // Check if user owns the project that this task belongs to
+ const { data: projectData, error: projectError } = await supabaseAdmin
+ .from("projects")
+ .select("user_id")
+ .eq("id", taskData.project_id)
+ .single();
+
+ if (projectError) {
+ console.error("Project lookup error:", projectError);
+ return new Response("Project not found", { status: 404 });
+ }
+
+ if (projectData.user_id !== callerId) {
+ console.error(`Task ${task_id} belongs to project ${taskData.project_id} owned by ${projectData.user_id}, not user ${callerId}`);
+ return new Response("Forbidden: Task does not belong to user", { status: 403 });
+ }
+
+ console.log(`Task ${task_id} authorization verified for user ${callerId}`);
+ }
+
+ // Return the dependency info
+ if (!taskData.dependant_on) {
+ // No dependency
+ return new Response(JSON.stringify({
+ predecessor_id: null,
+ output_location: null
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+ }
+
+ // Get the predecessor task details
+ const { data: predecessorData, error: predecessorError } = await supabaseAdmin
+ .from("tasks")
+ .select("id, status, output_location")
+ .eq("id", taskData.dependant_on)
+ .single();
+
+ if (predecessorError) {
+ console.error("Predecessor lookup error:", predecessorError);
+ // Dependency exists but predecessor task not found
+ return new Response(JSON.stringify({
+ predecessor_id: taskData.dependant_on,
+ output_location: null,
+ status: "not_found"
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+ }
+
+ if (predecessorData.status !== "Complete" || !predecessorData.output_location) {
+ // Dependency exists but not complete or no output
+ return new Response(JSON.stringify({
+ predecessor_id: taskData.dependant_on,
+ output_location: null,
+ status: predecessorData.status
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+ }
+
+ // Dependency is complete with output
+ console.log(`Found predecessor output: ${predecessorData.id} -> ${predecessorData.output_location}`);
+ return new Response(JSON.stringify({
+ predecessor_id: predecessorData.id,
+ output_location: predecessorData.output_location
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+
+ } catch (error) {
+ console.error("Edge function error:", error);
+ return new Response(`Internal error: ${error.message}`, { status: 500 });
+ }
+});
\ No newline at end of file
diff --git a/supabase/functions/get-predecessor-output/supabase.toml b/supabase/functions/get-predecessor-output/supabase.toml
new file mode 100644
index 000000000..b891ade37
--- /dev/null
+++ b/supabase/functions/get-predecessor-output/supabase.toml
@@ -0,0 +1,2 @@
+[functions.get-predecessor-output]
+verify_jwt = false
\ No newline at end of file
diff --git a/supabase/functions/update-task-status/index.ts b/supabase/functions/update-task-status/index.ts
new file mode 100644
index 000000000..bad355089
--- /dev/null
+++ b/supabase/functions/update-task-status/index.ts
@@ -0,0 +1,218 @@
+import { serve } from "https://deno.land/std@0.224.0/http/server.ts";
+import { createClient } from "https://esm.sh/@supabase/supabase-js@2.39.7";
+
+/**
+ * Edge function: update-task-status
+ *
+ * Updates a task's status and optionally sets output_location.
+ * - Service-role key: can update any task across all users
+ * - User token: can only update tasks for that specific user's projects
+ *
+ * POST /functions/v1/update-task-status
+ * Headers: Authorization: Bearer <token>
+ * Body: {
+ * "task_id": "uuid-string",
+ * "status": "In Progress" | "Failed" | "Complete",
+ * "output_location": "optional-string"
+ * }
+ *
+ * Returns:
+ * - 200 OK with success message
+ * - 400 Bad Request if missing required fields
+ * - 401 Unauthorized if no valid token
+ * - 403 Forbidden if token invalid or user not found
+ * - 404 Not Found if task doesn't exist or user can't access it
+ * - 500 Internal Server Error
+ */
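+
+// Example call (illustrative sketch only; URL, token and values are placeholders —
+// assumes the function is deployed at the project's standard Functions URL):
+//
+//   await fetch(`${SUPABASE_URL}/functions/v1/update-task-status`, {
+//     method: "POST",
+//     headers: { Authorization: `Bearer ${token}`, "Content-Type": "application/json" },
+//     body: JSON.stringify({
+//       task_id: "some-task-uuid",
+//       status: "Complete",
+//       output_location: "path/or/url/to/output.mp4", // optional
+//     }),
+//   });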
+
+serve(async (req) => {
+ // Only accept POST requests
+ if (req.method !== "POST") {
+ return new Response("Method not allowed", { status: 405 });
+ }
+
+ // Extract authorization header
+ const authHeader = req.headers.get("Authorization");
+ if (!authHeader?.startsWith("Bearer ")) {
+ return new Response("Missing or invalid Authorization header", { status: 401 });
+ }
+
+ const token = authHeader.slice(7); // Remove "Bearer " prefix
+ const serviceKey = Deno.env.get("SUPABASE_SERVICE_ROLE_KEY");
+ const supabaseUrl = Deno.env.get("SUPABASE_URL");
+
+ if (!serviceKey || !supabaseUrl) {
+ console.error("Missing required environment variables");
+ return new Response("Server configuration error", { status: 500 });
+ }
+
+ // Parse request body
+ let requestBody: any = {};
+ try {
+ const bodyText = await req.text();
+ if (bodyText) {
+ requestBody = JSON.parse(bodyText);
+ }
+ } catch (e) {
+ return new Response("Invalid JSON body", { status: 400 });
+ }
+
+ // Validate required fields
+ const { task_id, status } = requestBody;
+ if (!task_id || !status) {
+ return new Response("Missing required fields: task_id and status", { status: 400 });
+ }
+
+ // Validate status values
+ const validStatuses = ["Queued", "In Progress", "Complete", "Failed"];
+ if (!validStatuses.includes(status)) {
+ return new Response(`Invalid status. Must be one of: ${validStatuses.join(", ")}`, { status: 400 });
+ }
+
+ // Create admin client for database operations
+ const supabaseAdmin = createClient(supabaseUrl, serviceKey);
+
+ let callerId: string | null = null;
+ let isServiceRole = false;
+
+ // 1) Check if token matches service-role key directly
+ if (token === serviceKey) {
+ isServiceRole = true;
+ console.log("Direct service-role key match");
+ }
+
+ // 2) If not service key, try to decode as JWT and check role
+ if (!isServiceRole) {
+ try {
+ const parts = token.split(".");
+ if (parts.length === 3) {
+ // It's a JWT - decode and check role
+ // JWT payloads are base64url-encoded; map URL-safe chars to standard base64 before atob()
+ const payloadB64 = parts[1].replace(/-/g, "+").replace(/_/g, "/");
+ const padded = payloadB64 + "=".repeat((4 - (payloadB64.length % 4)) % 4);
+ const payload = JSON.parse(atob(padded));
+
+ // Check for service role in various claim locations
+ const role = payload.role || payload.app_metadata?.role;
+ if (["service_role", "supabase_admin"].includes(role)) {
+ isServiceRole = true;
+ console.log("JWT has service-role/admin role");
+ }
+ // Don't extract user ID from JWT - always look it up in the user_api_tokens table
+ }
+ } catch (e) {
+ // Not a valid JWT - will be treated as PAT
+ console.log("Token is not a valid JWT, treating as PAT");
+ }
+ }
+
+ // 3) USER TOKEN PATH - ALWAYS resolve callerId via the user_api_tokens table
+ if (!isServiceRole) {
+ console.log("Looking up token in user_api_tokens table...");
+
+ try {
+ // Query user_api_tokens table to find user
+ const { data, error } = await supabaseAdmin
+ .from("user_api_tokens")
+ .select("user_id")
+ .eq("token", token)
+ .single();
+
+ if (error || !data) {
+ console.error("Token lookup failed:", error);
+ return new Response("Invalid or expired token", { status: 403 });
+ }
+
+ callerId = data.user_id;
+ console.log(`Token resolved to user ID: ${callerId}`);
+ } catch (e) {
+ console.error("Error querying user_api_token:", e);
+ return new Response("Token validation failed", { status: 403 });
+ }
+ }
+
+ try {
+ // Build update payload
+ const updatePayload: any = {
+ status: status,
+ updated_at: new Date().toISOString()
+ };
+
+ // Add optional fields based on status
+ if (status === "In Progress") {
+ updatePayload.generation_started_at = new Date().toISOString();
+ }
+
+ if (requestBody.output_location) {
+ updatePayload.output_location = requestBody.output_location;
+ }
+
+ let updateResult;
+
+ if (isServiceRole) {
+ // Service role: can update any task
+ console.log(`Service role: Updating task ${task_id} to status '${status}'`);
+
+ updateResult = await supabaseAdmin
+ .from("tasks")
+ .update(updatePayload)
+ .eq("id", task_id)
+ .select()
+ .single();
+
+ } else {
+ // User token: can only update tasks in their projects
+ console.log(`User ${callerId}: Updating task ${task_id} to status '${status}'`);
+
+ // First get user's project IDs
+ const { data: userProjects } = await supabaseAdmin
+ .from("projects")
+ .select("id")
+ .eq("user_id", callerId);
+
+ if (!userProjects || userProjects.length === 0) {
+ return new Response("User has no projects", { status: 403 });
+ }
+
+ const projectIds = userProjects.map(p => p.id);
+
+ // Update task only if it belongs to user's projects
+ updateResult = await supabaseAdmin
+ .from("tasks")
+ .update(updatePayload)
+ .eq("id", task_id)
+ .in("project_id", projectIds)
+ .select()
+ .single();
+ }
+
+ if (updateResult.error) {
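+ // PGRST116 is PostgREST's "no rows" error from .single(): the id / project filter matched nothing.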
+ if (updateResult.error.code === "PGRST116") {
+ console.log(`Task ${task_id} not found or not accessible`);
+ return new Response("Task not found or not accessible", { status: 404 });
+ }
+ console.error("Update error:", updateResult.error);
+ return new Response(`Database error: ${updateResult.error.message}`, { status: 500 });
+ }
+
+ if (!updateResult.data) {
+ console.log(`Task ${task_id} not found or not accessible`);
+ return new Response("Task not found or not accessible", { status: 404 });
+ }
+
+ console.log(`Successfully updated task ${task_id} to status '${status}'`);
+
+ return new Response(JSON.stringify({
+ success: true,
+ task_id: task_id,
+ status: status,
+ message: `Task status updated to '${status}'`
+ }), {
+ status: 200,
+ headers: { "Content-Type": "application/json" }
+ });
+
+ } catch (error) {
+ console.error("Unexpected error:", error);
+ return new Response(`Internal server error: ${error.message}`, { status: 500 });
+ }
+});
\ No newline at end of file
diff --git a/tasks/restructure_image_travel.md b/tasks/restructure_image_travel.md
deleted file mode 100644
index d5374cccd..000000000
--- a/tasks/restructure_image_travel.md
+++ /dev/null
@@ -1,132 +0,0 @@
-# Restructuring plan: Travel-Between-Images → Fully-Queued Workflow
-
-This document describes the **structural changes** required to let `steerable_motion.travel_between_images` enqueue _all_ segment tasks at once and let `headless.py` take care of creating guide videos & chaining, instead of the current _serial_ approach.
-
----
-
-## 1. Goals
-
-1. **No per-segment polling in `travel_between_images.py`** – it should finish quickly after pushing **N** segment tasks (and one orchestrator task) into the DB.
-2. **`headless.py` becomes responsible for the dependency chain**:
- • wait for a segment's video to finish
- • create its guide artefacts for the _next_ segment
- • enqueue the next segment automatically
-3. Maintain existing CLI/behaviour for callers (only runtime improves).
-
-
-## 2. High-Level Design
-
-```
-steerable_motion.py headless.py (server loop) wgp.py
-┌──────────────┐ ┌────────────────────────────┐ ┌─────────┐
-│travel_between│ 1. N+1 │ (A) orchestrator task │ │ video │
-│ images │ push → │ (B) segment-0 task │ ----> │gen. │
-└──────────────┘ └────────────┬───────────────┘ └─────────┘
- 2. segment-0 done
- ├─▶ create guide-1
- ├─▶ enqueue segment-1
- └─▶ repeat …
-```
-
-* **Orchestrator task** (new `task_type = "travel_orchestrator"`) carries the full sequence definition (image list, prompts, overlaps, etc.).
-* **Segment task** (`task_type = "travel_segment"`) generates the actual video for a single leg. It stores its output location in DB for successor lookup.
-
-
-## 3. Database / Schema Updates
-
-1. **tasks table** already has `task_type` – good.
-2. Add **`payload` JSONB fields** (optional) if we need to store large sequencing info separate from `params` (optional).
-3. No blocking change if we piggy-back on existing `params`.
-
-
-## 4. Changes in `travel_between_images.py`
-
-### 4.1 Remove per-segment loop / polling
-* Build a single `orchestrator_payload` containing:
- * ordered `input_images`
- * `{base_prompts, negative_prompts, segment_frames, frame_overlap}` (expanded to per-segment arrays so headless has no math to do)
- * common settings from `common_args`
-* Enqueue:
- ```python
- add_task_to_db(orchestrator_payload, db_path, task_type="travel_orchestrator")
- ```
-* **Do _not_** call `poll_task_status`.
-* Exit immediately (status msgs only).
-
-### 4.2 Delete big helper code blocks now destined for `headless.py`:
-* guide-video construction
-* frame extraction helpers
-* cross-fade stitching (these belong to server side now)
-
-> Keep generic utils (e.g. `_get_unique_target_path`) in `sm_functions.common_utils` so both modules can import.
-
-
-## 5. Add logic to `headless.py`
-
-### 5.1 New task handlers
-1. **`_handle_travel_orchestrator_task()`**
- * Reads sequence definition.
- * Immediately enqueues **segment-0** (`travel_segment`, `segment_index=0`, no guide video yet if `continue_from_video` absent).
-2. **`_handle_travel_segment_task()`** (refactor of current generation part)
- * Runs WGP like today (no cross-fade/stitch yet).
- * After completion:
- 1. Save output path to DB (already happens).
- 2. Look inside parent orchestrator payload to determine **if more segments remain**.
- * If yes: create guide video for **next segment** by:
- * extracting overlap frames from the just-rendered mp4 (reuse existing helper).
- * building guide video frames (reuse logic moved from `travel_between_images`).
- * Enqueue next segment task referencing new guide path and `previous_segment_task_id` for dependency clarity.
- 3. If last segment: enqueue **stitch task** (`task_type="travel_stitch"`) that waits for all segment mp4s then cross-fades & (optionally) upscales.
-
-### 5.2 Guide / Stitch utilities
-* Move these from `travel_between_images`:
- * `extract_frames_from_video`, `cross_fade_overlap_frames`, `create_video_from_frames_list`, `_apply_saturation_to_video_ffmpeg`, etc.
-* Place in `sm_functions.common_utils` _or_ a new `sm_functions.video_utils` to keep headless lean.
-
-### 5.3 DB Dependency helpers (optional but nice)
-* Add `depends_on_task_id` column **OR** store list in `params`.
-* When claiming tasks headless should:
- * skip tasks whose `depends_on_task_id` isn't `STATUS_COMPLETE` yet.
- * simple where-clause + ORDER BY is enough.
-
-### 5.4 Polling / Concurrency
-* Because tasks are now independent rows, multiple workers can process parallel segments (except each one waits on previous due to depends_on).
-
-
-## 6. Transitional Considerations
-
-* Provide **feature flag** (e.g. `--legacy_travel_flow`) to fall back to old behaviour until new path is proven.
-* Update unit tests & CI.
-* Ensure migration script adds any new DB columns.
-
-
-## 7. File/Module Moves Summary
-
-| From `travel_between_images.py` | To |
-|--------------------------------------------------------|------------------------------------|
-| `extract_frames_from_video`, `cross_fade_overlap_frames` | `sm_functions.video_utils` |
-| Guide-video build logic (both first & subsequent segs) | `_handle_travel_segment_task` |
-| Final stitching (`create_video_from_frames_list`, etc.)| new `_handle_travel_stitch_task` |
-
-Other helper funcs (`_get_unique_target_path`, `_adjust_frame_brightness`, …) should reside in `common_utils` if not already.
-
-
-## 8. CLI / API Changes
-
-* **No change** for end-user calling `steerable_motion.py travel_between_images …` – all params are forwarded inside orchestrator payload.
-* New internal `task_type` strings: `travel_orchestrator`, `travel_segment`, `travel_stitch`.
-
-
-## 9. Estimated Work Breakdown
-
-1. ✂️ Refactor helpers into shared util module – **½ day**
-2. 📝 Create new handlers in `headless.py` – **1 day**
-3. 🔄 Rewrite `travel_between_images.py` to enqueue orchestrator only – **½ day**
-4. 🔗 Add dependency tracking & DB schema migration – **½ day**
-5. 🧪 Testing (unit + e2e) – **1 day**
-
-_Total: ~3 days of focused work._
-
----
-
-### End of Plan
\ No newline at end of file
diff --git a/tasks/update_wgp_latest_headless.md b/tasks/update_wgp_latest_headless.md
deleted file mode 100644
index 2a2d8a9b7..000000000
--- a/tasks/update_wgp_latest_headless.md
+++ /dev/null
@@ -1,48 +0,0 @@
-# Wan2GP compatibility update – headless.py vs. upstream commit 7670af9
-
-## Context
-The upstream `Wan2GP` repository has been updated (local diff `6706709` → `7670af9`). The main breaking changes impacting `headless.py` are inside `wgp.py`:
-
-* `generate_video(...)` renamed a keyword argument:
- * **`remove_background_image_ref` → `remove_background_images_ref`** (plural)
-* `generate_video(...)` now expects `state["gen"]["file_settings_list"]` to be present (list of per-file settings), similar to existing `file_list`.
-
-Without adapting our wrapper these changes cause `TypeError` (unexpected keyword) or `KeyError`.
-
----
-
-## Required code changes (high-priority)
-
-- [ ] **Update the call into `wgp_mod.generate_video` (`process_single_task`):**
- ```python
- # BEFORE
- remove_background_image_ref = ui_params.get("remove_background_image_ref", 1)
-
- # AFTER (support both keys)
- remove_background_images_ref = ui_params.get("remove_background_images_ref", ui_params.get("remove_background_image_ref", 1))
- ```
-
-- [ ] **Rename the keyword in the actual function call** so we pass `remove_background_images_ref=...`.
-
-- [ ] **Augment the generated state dicts**
- * In `build_task_state(...)` and the minimal state inside `_handle_rife_interpolate_task`, add an empty list entry:
- ```python
- "file_settings_list": []
- ```
- This prevents `KeyError` at the top of the new `generate_video`.
-
-- [ ] **Optionally map legacy task JSON**
- * When building `ui_defaults`, copy any incoming `remove_background_image_ref` value into `remove_background_images_ref` for forward-compat.
-
----
-
-## Nice-to-have / follow-up
-
-- [ ] Confirm no other renamed parameters (scan changed signature lines).
-- [ ] Run an integration test generating a short clip to validate that `headless.py` operates end-to-end after the patch.
-- [ ] Bump internal version / changelog entry: "Compatibility with Wan2GP `7670af9`".
-
----
-
-## Environment / CI notes
-No new environment variables are required; the update is purely within API surface.
\ No newline at end of file