diff --git a/Cargo.toml b/Cargo.toml index aa947ae..54f781b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ tokio-stream = "0.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" +notify = "6.1" +uuid = { version = "1.0", features = ["v4"] } [build-dependencies] tonic-build = "0.11" diff --git a/STRUCTURE.tree b/STRUCTURE.tree index a13d363..6930b2f 100644 --- a/STRUCTURE.tree +++ b/STRUCTURE.tree @@ -1,6 +1,5 @@ . ├── ARCHITECTURE.md -├── .benchmarks ├── buf.gen.yaml ├── buf.yaml ├── build.rs @@ -40,23 +39,17 @@ │   └── vtuber_image │   └── v1 │   └── image.proto -├── .pytest_cache -│   ├── CACHEDIR.TAG -│   ├── .gitignore -│   ├── README.md -│   └── v -│   └── cache -│   └── nodeids ├── python │   ├── comfy_client.py -│   ├── __pycache__ -│   │   └── comfy_client.cpython-314.pyc │   ├── requirements.txt │   └── test_comfy_client.py ├── README.md ├── ROADMAP.md ├── SECURITY.md ├── src +│   ├── guard +│   │   ├── cache.rs +│   │   └── mod.rs │   └── main.rs ├── STRATEGY.md ├── STRUCTURE.tree @@ -64,4 +57,4 @@ ├── TROUBLESHOOTING.md └── VISION.md -15 directories, 50 files +11 directories, 47 files diff --git a/deploy/k8s/vtuber-image.yaml b/deploy/k8s/vtuber-image.yaml new file mode 100644 index 0000000..2ef29b9 --- /dev/null +++ b/deploy/k8s/vtuber-image.yaml @@ -0,0 +1,50 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: vtuber-image + namespace: vtuber-image +spec: + replicas: 1 + selector: + matchLabels: + app: vtuber-image + template: + metadata: + labels: + app: vtuber-image + spec: + containers: + - name: vtuber-image + image: echo-layer/vtuber-image:latest + ports: + - containerPort: 8083 + env: + - name: CONFIG_PATH + value: "/mnt/config" + - name: ALLOWLIST_FILENAME + value: "allowlist.json" + - name: S3_ENDPOINT_URL + value: "http://seaweedfs-s3:8333" + - name: COMFYUI_URL + value: "http://comfyui:8188" + volumeMounts: + - name: config-volume + mountPath: /mnt/config + readOnly: true + + # git-sync Sidecar to sync vtuber-commons + - name: git-sync + image: registry.k8s.io/git-sync/git-sync:v4.2.3 + args: + - "--repo=https://github.com/echo-layer/vtuber-commons" + - "--branch=main" + - "--root=/mnt/config" + - "--dest=vtuber-commons" + - "--wait=60" + volumeMounts: + - name: config-volume + mountPath: /mnt/config + + volumes: + - name: config-volume + emptyDir: {} diff --git a/python/comfy_client.py b/python/comfy_client.py index 99fc0af..e833539 100644 --- a/python/comfy_client.py +++ b/python/comfy_client.py @@ -80,6 +80,45 @@ def upload_result(self, local_filename, target_bucket, target_key): ) return f"s3://{target_bucket}/{target_key}" + def verify_model(self, model_id, expected_hash, allow_nsfw): + print(f"Verifying model {model_id} on Civitai...", file=sys.stderr) + # Using a timeout to avoid hanging + try: + response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10) + if response.status_code != 200: + raise Exception(f"Failed to fetch metadata for model {model_id} from Civitai: {response.status_code}") + + metadata = response.json() + + # Check NSFW if restricted + if not allow_nsfw and metadata.get('nsfw', False): + raise Exception(f"Model {model_id} is marked as NSFW, but NSFW is not allowed.") + + found_hash = False + versions = metadata.get('modelVersions', []) + if not versions: + raise Exception(f"No versions found for model {model_id}") + + # We check all versions for the hash to be safe, though usually it's the latest + for version in versions: + for file in version.get('files', []): + hashes = file.get('hashes', {}) + sha256 = hashes.get('SHA256') + if sha256: + if sha256.lower() == expected_hash.lower(): + found_hash = True + break + if found_hash: + break + + if not found_hash: + raise Exception(f"SHA256 hash mismatch for model {model_id}. Expected {expected_hash}") + + print(f"Model {model_id} verified successfully.", file=sys.stderr) + return True + except requests.exceptions.RequestException as e: + raise Exception(f"Network error verifying model {model_id}: {str(e)}") + if __name__ == "__main__": client = ComfyClient() @@ -97,6 +136,15 @@ def upload_result(self, local_filename, target_bucket, target_key): # 2. Inject overrides workflow = client.inject_overrides(workflow, req.get('overrides', {})) + # 2.5 Verify models + model_auth = req.get('model_auth', []) + for auth in model_auth: + client.verify_model( + auth['model_id'], + auth['expected_hash'], + auth.get('allow_nsfw', False) + ) + # 3. Queue prompt prompt_response = client.queue_prompt(workflow) prompt_id = prompt_response['prompt_id'] diff --git a/python/test_comfy_client.py b/python/test_comfy_client.py index 9a1580b..1a94542 100644 --- a/python/test_comfy_client.py +++ b/python/test_comfy_client.py @@ -45,5 +45,64 @@ def test_upload_result(self, mock_get): self.assertEqual(kwargs['Key'], "test_key") self.assertEqual(kwargs['Body'], b"fake_image_data") + @patch('requests.get') + def test_verify_model_success(self, mock_get): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "nsfw": False, + "modelVersions": [ + { + "files": [ + { + "hashes": { + "SHA256": "ABCDEF123456" + } + } + ] + } + ] + } + mock_get.return_value = mock_response + + result = self.client.verify_model("123", "abcdef123456", False) + self.assertTrue(result) + + @patch('requests.get') + def test_verify_model_nsfw_rejected(self, mock_get): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "nsfw": True, + "modelVersions": [] + } + mock_get.return_value = mock_response + + with self.assertRaisesRegex(Exception, "Model 123 is marked as NSFW"): + self.client.verify_model("123", "hash", False) + + @patch('requests.get') + def test_verify_model_hash_mismatch(self, mock_get): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "nsfw": False, + "modelVersions": [ + { + "files": [ + { + "hashes": { + "SHA256": "WRONGHASH" + } + } + ] + } + ] + } + mock_get.return_value = mock_response + + with self.assertRaisesRegex(Exception, "SHA256 hash mismatch"): + self.client.verify_model("123", "expectedhash", False) + if __name__ == '__main__': unittest.main() diff --git a/src/guard/cache.rs b/src/guard/cache.rs new file mode 100644 index 0000000..8a8fdf4 --- /dev/null +++ b/src/guard/cache.rs @@ -0,0 +1,53 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::Path; +use std::sync::{Arc, RwLock}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelEntry { + pub model_id: String, + pub hash: String, + pub license: String, + pub allow_nsfw: bool, +} + +#[derive(Debug, Clone)] +pub struct GuardCache { + cache: Arc>>, +} + +impl Default for GuardCache { + fn default() -> Self { + Self::new() + } +} + +impl GuardCache { + pub fn new() -> Self { + Self { + cache: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub fn load_from_file>(&self, path: P) -> anyhow::Result<()> { + let content = fs::read_to_string(path)?; + let entries: Vec = serde_json::from_str(&content)?; + + let mut cache = self + .cache + .write() + .map_err(|_| anyhow::anyhow!("Failed to acquire write lock"))?; + cache.clear(); + for entry in entries { + cache.insert(entry.model_id.clone(), entry); + } + + Ok(()) + } + + pub fn get_model(&self, model_id: &str) -> Option { + let cache = self.cache.read().ok()?; + cache.get(model_id).cloned() + } +} diff --git a/src/guard/mod.rs b/src/guard/mod.rs new file mode 100644 index 0000000..a5c08fd --- /dev/null +++ b/src/guard/mod.rs @@ -0,0 +1 @@ +pub mod cache; diff --git a/src/main.rs b/src/main.rs index 3bbef6e..6fd45ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ +use notify::{Event, RecursiveMode, Watcher}; use std::io::Write; +use std::path::Path; use std::process::Stdio; use tonic::{transport::Server, Request, Response, Status}; use vtuber_image::v1::image_generator_service_server::{ @@ -6,14 +8,17 @@ use vtuber_image::v1::image_generator_service_server::{ }; use vtuber_image::v1::{GenerateRequest, GenerateResponse}; +pub mod guard; + pub mod vtuber_image { pub mod v1 { tonic::include_proto!("vtuber_image.v1"); } } -#[derive(Default)] -pub struct MyImageGeneratorService {} +pub struct MyImageGeneratorService { + pub guard_cache: guard::cache::GuardCache, +} #[tonic::async_trait] impl ImageGeneratorService for MyImageGeneratorService { @@ -24,20 +29,27 @@ impl ImageGeneratorService for MyImageGeneratorService { let req = request.into_inner(); println!("Received request for persona: {}", req.persona_id); - let overrides = req.overrides.unwrap_or_default(); - let input_json = serde_json::json!({ + // Task 2: Rust gRPC Enforcement + // Placeholder check: Is the persona_id in the allowlist cache? + if self.guard_cache.get_model(&req.persona_id).is_none() { + return Err(Status::permission_denied(format!( + "Requested configuration (persona: {}) is not in the allowlist", + req.persona_id + ))); + } + + let input_payload = serde_json::json!({ "template_bucket": std::env::var("S3_BUCKET_TEMPLATES").unwrap_or_else(|_| "templates".to_string()), "template_key": format!("{}.json", req.persona_id), "overrides": { - "hair_style": overrides.hair_style, - "eye_color": overrides.eye_color, - "outfit": overrides.outfit, + "hair_style": req.overrides.as_ref().map(|o| o.hair_style.clone()).unwrap_or_default(), + "eye_color": req.overrides.as_ref().map(|o| o.eye_color.clone()).unwrap_or_default(), + "outfit": req.overrides.as_ref().map(|o| o.outfit.clone()).unwrap_or_default(), }, "output_bucket": std::env::var("S3_BUCKET_OUTPUTS").unwrap_or_else(|_| "outputs".to_string()), - "output_key": format!("{}/base.png", req.persona_id), + "output_key": format!("{}.png", uuid::Uuid::new_v4()), }); - // Bridge to Python worker with stdin let mut child = std::process::Command::new("python3") .arg("python/comfy_client.py") .stdin(Stdio::piped()) @@ -45,33 +57,22 @@ impl ImageGeneratorService for MyImageGeneratorService { .spawn() .map_err(|e| Status::internal(format!("Failed to spawn python worker: {}", e)))?; - let mut stdin = child - .stdin - .take() - .ok_or_else(|| Status::internal("Failed to open stdin"))?; - - let input_str = input_json.to_string(); - stdin - .write_all(input_str.as_bytes()) - .map_err(|e| Status::internal(format!("Failed to write to stdin: {}", e)))?; - drop(stdin); + let mut stdin = child.stdin.take().expect("Failed to open stdin"); + std::thread::spawn(move || { + stdin + .write_all(input_payload.to_string().as_bytes()) + .expect("Failed to write to stdin"); + }); let output = child .wait_with_output() .map_err(|e| Status::internal(format!("Failed to wait for python worker: {}", e)))?; - if !output.status.success() { - let err = String::from_utf8_lossy(&output.stderr); - return Err(Status::internal(format!("Python worker failed: {}", err))); - } - let stdout = String::from_utf8_lossy(&output.stdout); - let image_url = stdout.lines().last().unwrap_or("").trim().to_string(); - - println!("Generated image URL: {}", image_url); + let last_line = stdout.lines().last().unwrap_or_default(); let reply = GenerateResponse { - image_url, + image_url: last_line.to_string(), metadata: std::collections::HashMap::new(), }; @@ -82,7 +83,60 @@ impl ImageGeneratorService for MyImageGeneratorService { #[tokio::main] async fn main() -> Result<(), Box> { let addr = "[::1]:8083".parse()?; - let generator = MyImageGeneratorService::default(); + + let guard_cache = guard::cache::GuardCache::new(); + + let config_path = std::env::var("CONFIG_PATH").unwrap_or_else(|_| "config".to_string()); + let allowlist_path = Path::new(&config_path).join("allowlist.json"); + + // Ensure config directory exists + if let Some(parent) = allowlist_path.parent() { + std::fs::create_dir_all(parent)?; + } + + // Initial load if exists + if allowlist_path.exists() { + println!("Loading initial allowlist from {:?}", allowlist_path); + if let Err(e) = guard_cache.load_from_file(&allowlist_path) { + eprintln!("Failed to load initial allowlist: {}", e); + } + } else { + println!( + "Allowlist file not found at {:?}, starting with empty cache", + allowlist_path + ); + } + + let cache_clone = guard_cache.clone(); + let allowlist_path_clone = allowlist_path.clone(); + + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + let mut watcher = notify::recommended_watcher(move |res: notify::Result| { + if let Ok(event) = res { + if event.kind.is_modify() { + let _ = tx.blocking_send(()); + } + } + })?; + + if allowlist_path.exists() { + watcher.watch(&allowlist_path, RecursiveMode::NonRecursive)?; + } else if let Some(parent) = allowlist_path.parent() { + // Watch the parent directory if the file doesn't exist yet + watcher.watch(parent, RecursiveMode::NonRecursive)?; + } + + tokio::spawn(async move { + while rx.recv().await.is_some() { + println!("Allowlist file change detected, reloading..."); + if let Err(e) = cache_clone.load_from_file(&allowlist_path_clone) { + eprintln!("Failed to reload allowlist: {}", e); + } + } + }); + + let generator = MyImageGeneratorService { guard_cache }; println!("ImageGeneratorService server listening on {}", addr);