class ComfyClient:
    """Run a ComfyUI workflow over HTTP and store the result in S3.

    The client talks to a ComfyUI server at ``server_address`` and to an
    S3-compatible endpoint configured through the ``S3_ENDPOINT_URL``,
    ``S3_ACCESS_KEY`` and ``S3_SECRET_KEY`` environment variables (a ``.env``
    file is loaded first via ``load_dotenv``).
    """

    def __init__(self, server_address="http://localhost:8188"):
        """Create the client and its S3 handle.

        Args:
            server_address: Base URL of the ComfyUI server, without a
                trailing slash (paths are appended with a leading ``/``).
        """
        load_dotenv()
        self.server_address = server_address
        # Unique per client instance; sent with every queued prompt.
        self.client_id = str(uuid.uuid4())

        # S3 / SeaweedFS configuration, with local-development defaults.
        s3_endpoint = os.getenv("S3_ENDPOINT_URL", "http://localhost:8333")
        s3_access_key = os.getenv("S3_ACCESS_KEY", "any")
        s3_secret_key = os.getenv("S3_SECRET_KEY", "any")

        self.s3 = boto3.client(
            's3',
            endpoint_url=s3_endpoint,
            aws_access_key_id=s3_access_key,
            aws_secret_access_key=s3_secret_key,
        )

    def fetch_template(self, bucket, key):
        """Download ``key`` from ``bucket`` and parse it as UTF-8 JSON.

        Returns:
            The decoded JSON document (the workflow template).
        """
        response = self.s3.get_object(Bucket=bucket, Key=key)
        return json.loads(response['Body'].read().decode('utf-8'))

    def inject_overrides(self, workflow_json, overrides):
        """Replace ``{{name}}`` placeholders in prompt/text nodes, in place.

        Only nodes whose ``class_type`` or ``_meta.title`` contains
        ``"Prompt"`` or ``"Text"`` are touched, and only their string inputs.

        Args:
            workflow_json: Mapping of node id to node dict. Mutated in place.
            overrides: Mapping of placeholder name to replacement value, or
                ``None``. Values that are ``None`` are treated as "no override
                supplied" and leave their placeholder untouched; any other
                non-string value is converted with ``str``.

        Returns:
            The same ``workflow_json`` object that was passed in.
        """
        # Build the placeholder -> text table once instead of per input.
        replacements = {
            f"{{{{{key}}}}}": str(value)
            for key, value in (overrides or {}).items()
            if value is not None
        }
        if not replacements:
            return workflow_json

        terms = ("Prompt", "Text")
        for node in workflow_json.values():
            if not isinstance(node, dict):
                continue

            class_type = node.get('class_type') or ''
            title = (node.get('_meta') or {}).get('title') or ''
            if not any(term in class_type or term in title for term in terms):
                continue

            inputs = node.get('inputs')
            if not isinstance(inputs, dict):
                continue

            # Re-assigning existing keys while iterating is safe: the dict
            # never changes size.
            for input_key, input_value in inputs.items():
                if not isinstance(input_value, str):
                    continue
                for placeholder, text in replacements.items():
                    input_value = input_value.replace(placeholder, text)
                inputs[input_key] = input_value
        return workflow_json

    def queue_prompt(self, prompt):
        """POST ``prompt`` to the server's ``/prompt`` endpoint.

        Returns:
            The decoded JSON body of the server's response.
        """
        p = {"prompt": prompt, "client_id": self.client_id}
        # NOTE(review): this assignment sits in the gap between two diff hunks
        # and was reconstructed from the ``data=data`` use below -- confirm
        # against the real file.
        data = json.dumps(p).encode('utf-8')
        response = requests.post(f"{self.server_address}/prompt", data=data)
        return response.json()

    def wait_for_completion(self, prompt_id, timeout=None, poll_interval=1.0):
        """Poll ``/history/<prompt_id>`` until the prompt has an image output.

        Args:
            prompt_id: Id returned when the prompt was queued.
            timeout: Maximum number of seconds to wait, or ``None`` (default)
                to wait indefinitely.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The ``filename`` of the first image found in the prompt's outputs.

        Raises:
            RuntimeError: The history entry exists but contains no images.
            TimeoutError: ``timeout`` elapsed before a history entry appeared.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            # Kept as a single positional URL argument: the unit tests assert
            # on this exact call shape.
            response = requests.get(f"{self.server_address}/history/{prompt_id}")
            history = response.json()
            if prompt_id in history:
                entry = history[prompt_id]
                outputs = entry.get("outputs") or {}
                for node_output in outputs.values():
                    images = node_output.get("images")
                    if images:
                        return images[0]["filename"]
                # Presumably a history entry only appears once execution has
                # ended (confirm against the ComfyUI API), so polling again
                # would spin forever.
                raise RuntimeError(
                    f"Prompt {prompt_id} finished without image outputs "
                    f"(status: {entry.get('status')})"
                )
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Prompt {prompt_id} did not complete within {timeout} seconds"
                )
            time.sleep(poll_interval)

    def upload_result(self, local_filename, target_bucket, target_key):
        """Copy a generated image from ComfyUI into S3.

        Args:
            local_filename: Image filename as reported by ComfyUI.
            target_bucket: Destination bucket.
            target_key: Destination object key.

        Returns:
            The ``s3://bucket/key`` URL of the uploaded object.

        Raises:
            requests.HTTPError: ComfyUI answered the download with an error
                status; nothing is uploaded in that case.
        """
        from urllib.parse import quote

        # Percent-encode so names containing spaces, '&', '#', ... survive the
        # query string; plain names produce the same URL as before.
        response = requests.get(
            f"{self.server_address}/view?filename={quote(local_filename)}"
        )
        # Never store an HTTP error page under an image key.
        response.raise_for_status()
        image_bytes = response.content

        # Upload to S3/SeaweedFS
        self.s3.put_object(
            Bucket=target_bucket,
            Key=target_key,
            Body=image_bytes,
            ContentType='image/png'
        )
        return f"s3://{target_bucket}/{target_key}"


if __name__ == "__main__":
    # Worker entry point: reads one JSON request from stdin, prints the
    # resulting S3 URL on stdout, and reports failures on stderr with exit
    # status 1. The Rust caller takes the last stdout line as the image URL.
    client = ComfyClient()

    # Read request from stdin
    input_data = sys.stdin.read()
    if not input_data:
        sys.exit(0)

    try:
        req = json.loads(input_data)

        # 1. Fetch template
        workflow = client.fetch_template(req['template_bucket'], req['template_key'])

        # 2. Inject overrides
        workflow = client.inject_overrides(workflow, req.get('overrides', {}))

        # 3. Queue prompt
        prompt_response = client.queue_prompt(workflow)
        prompt_id = prompt_response.get('prompt_id')
        if not prompt_id:
            # Surface the server's own explanation instead of a bare KeyError.
            raise RuntimeError(f"ComfyUI did not accept the prompt: {prompt_response}")

        # 4. Wait for completion
        filename = client.wait_for_completion(prompt_id)

        # 5. Upload result
        s3_url = client.upload_result(filename, req['output_bucket'], req['output_key'])

        # 6. Output result URL to stdout for Rust to pick up
        print(s3_url)

    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
self.assertEqual(filename, "test_image.png") + mock_get.assert_called_with(f"{self.client.server_address}/history/{prompt_id}") + + @patch('requests.get') + def test_upload_result(self, mock_get): + # Mock view response + mock_get.return_value.content = b"fake_image_data" + + # Mock S3 put_object + self.client.s3.put_object = MagicMock() + + result = self.client.upload_result("test_image.png", "test_bucket", "test_key") + + self.assertEqual(result, "s3://test_bucket/test_key") + mock_get.assert_called_with(f"{self.client.server_address}/view?filename=test_image.png") + self.client.s3.put_object.assert_called_once() + args, kwargs = self.client.s3.put_object.call_args + self.assertEqual(kwargs['Bucket'], "test_bucket") + self.assertEqual(kwargs['Key'], "test_key") + self.assertEqual(kwargs['Body'], b"fake_image_data") + +if __name__ == '__main__': + unittest.main() diff --git a/src/main.rs b/src/main.rs index 81ebd4d..3bbef6e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::io::Write; +use std::process::Stdio; use tonic::{transport::Server, Request, Response, Status}; use vtuber_image::v1::image_generator_service_server::{ ImageGeneratorService, ImageGeneratorServiceServer, @@ -22,16 +24,54 @@ impl ImageGeneratorService for MyImageGeneratorService { let req = request.into_inner(); println!("Received request for persona: {}", req.persona_id); - // Simple bridge to Python worker - let output = std::process::Command::new("python3") + let overrides = req.overrides.unwrap_or_default(); + let input_json = serde_json::json!({ + "template_bucket": std::env::var("S3_BUCKET_TEMPLATES").unwrap_or_else(|_| "templates".to_string()), + "template_key": format!("{}.json", req.persona_id), + "overrides": { + "hair_style": overrides.hair_style, + "eye_color": overrides.eye_color, + "outfit": overrides.outfit, + }, + "output_bucket": std::env::var("S3_BUCKET_OUTPUTS").unwrap_or_else(|_| "outputs".to_string()), + "output_key": format!("{}/base.png", req.persona_id), 
+ }); + + // Bridge to Python worker with stdin + let mut child = std::process::Command::new("python3") .arg("python/comfy_client.py") - .output() - .map_err(|e| Status::internal(format!("Failed to execute python worker: {}", e)))?; + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .map_err(|e| Status::internal(format!("Failed to spawn python worker: {}", e)))?; + + let mut stdin = child + .stdin + .take() + .ok_or_else(|| Status::internal("Failed to open stdin"))?; + + let input_str = input_json.to_string(); + stdin + .write_all(input_str.as_bytes()) + .map_err(|e| Status::internal(format!("Failed to write to stdin: {}", e)))?; + drop(stdin); + + let output = child + .wait_with_output() + .map_err(|e| Status::internal(format!("Failed to wait for python worker: {}", e)))?; + + if !output.status.success() { + let err = String::from_utf8_lossy(&output.stderr); + return Err(Status::internal(format!("Python worker failed: {}", err))); + } + + let stdout = String::from_utf8_lossy(&output.stdout); + let image_url = stdout.lines().last().unwrap_or("").trim().to_string(); - println!("Python output: {:?}", String::from_utf8_lossy(&output.stdout)); + println!("Generated image URL: {}", image_url); let reply = GenerateResponse { - image_url: "http://placeholder.com/image.png".to_string(), + image_url, metadata: std::collections::HashMap::new(), };