Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ tokio-stream = "0.1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
anyhow = "1.0"
notify = "6.1"
uuid = { version = "1.0", features = ["v4"] }

[build-dependencies]
tonic-build = "0.11"
15 changes: 4 additions & 11 deletions STRUCTURE.tree
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
.
β”œβ”€β”€ ARCHITECTURE.md
β”œβ”€β”€ .benchmarks
β”œβ”€β”€ buf.gen.yaml
β”œβ”€β”€ buf.yaml
β”œβ”€β”€ build.rs
Expand Down Expand Up @@ -40,28 +39,22 @@
β”‚Β Β  └── vtuber_image
β”‚Β Β  └── v1
β”‚Β Β  └── image.proto
β”œβ”€β”€ .pytest_cache
β”‚Β Β  β”œβ”€β”€ CACHEDIR.TAG
β”‚Β Β  β”œβ”€β”€ .gitignore
β”‚Β Β  β”œβ”€β”€ README.md
β”‚Β Β  └── v
β”‚Β Β  └── cache
β”‚Β Β  └── nodeids
β”œβ”€β”€ python
β”‚Β Β  β”œβ”€β”€ comfy_client.py
β”‚Β Β  β”œβ”€β”€ __pycache__
β”‚Β Β  β”‚Β Β  └── comfy_client.cpython-314.pyc
β”‚Β Β  β”œβ”€β”€ requirements.txt
β”‚Β Β  └── test_comfy_client.py
β”œβ”€β”€ README.md
β”œβ”€β”€ ROADMAP.md
β”œβ”€β”€ SECURITY.md
β”œβ”€β”€ src
β”‚Β Β  β”œβ”€β”€ guard
β”‚Β Β  β”‚Β Β  β”œβ”€β”€ cache.rs
β”‚Β Β  β”‚Β Β  └── mod.rs
β”‚Β Β  └── main.rs
β”œβ”€β”€ STRATEGY.md
β”œβ”€β”€ STRUCTURE.tree
β”œβ”€β”€ SUPPORT.md
β”œβ”€β”€ TROUBLESHOOTING.md
└── VISION.md

15 directories, 50 files
11 directories, 47 files
50 changes: 50 additions & 0 deletions deploy/k8s/vtuber-image.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: vtuber-image
namespace: vtuber-image
spec:
replicas: 1
selector:
matchLabels:
app: vtuber-image
template:
metadata:
labels:
app: vtuber-image
spec:
containers:
- name: vtuber-image
image: echo-layer/vtuber-image:latest
ports:
- containerPort: 8083
env:
- name: CONFIG_PATH
value: "/mnt/config"
- name: ALLOWLIST_FILENAME
value: "allowlist.json"
- name: S3_ENDPOINT_URL
value: "http://seaweedfs-s3:8333"
- name: COMFYUI_URL
value: "http://comfyui:8188"
volumeMounts:
- name: config-volume
mountPath: /mnt/config
readOnly: true

# git-sync Sidecar to sync vtuber-commons
- name: git-sync
image: registry.k8s.io/git-sync/git-sync:v4.2.3
args:
- "--repo=https://github.com/echo-layer/vtuber-commons"
- "--branch=main"
- "--root=/mnt/config"
- "--dest=vtuber-commons"
- "--wait=60"
volumeMounts:
- name: config-volume
mountPath: /mnt/config

volumes:
- name: config-volume
emptyDir: {}
48 changes: 48 additions & 0 deletions python/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,45 @@ def upload_result(self, local_filename, target_bucket, target_key):
)
return f"s3://{target_bucket}/{target_key}"

def verify_model(self, model_id, expected_hash, allow_nsfw):
print(f"Verifying model {model_id} on Civitai...", file=sys.stderr)
# Using a timeout to avoid hanging
try:
response = requests.get(f"https://civitai.com/api/v1/models/{model_id}", timeout=10)
if response.status_code != 200:
raise Exception(f"Failed to fetch metadata for model {model_id} from Civitai: {response.status_code}")

metadata = response.json()

# Check NSFW if restricted
if not allow_nsfw and metadata.get('nsfw', False):
raise Exception(f"Model {model_id} is marked as NSFW, but NSFW is not allowed.")

found_hash = False
versions = metadata.get('modelVersions', [])
if not versions:
raise Exception(f"No versions found for model {model_id}")

# We check all versions for the hash to be safe, though usually it's the latest
for version in versions:
for file in version.get('files', []):
hashes = file.get('hashes', {})
sha256 = hashes.get('SHA256')
if sha256:
if sha256.lower() == expected_hash.lower():
found_hash = True
break
if found_hash:
break

if not found_hash:
raise Exception(f"SHA256 hash mismatch for model {model_id}. Expected {expected_hash}")

print(f"Model {model_id} verified successfully.", file=sys.stderr)
return True
except requests.exceptions.RequestException as e:
raise Exception(f"Network error verifying model {model_id}: {str(e)}")

if __name__ == "__main__":
client = ComfyClient()

Expand All @@ -97,6 +136,15 @@ def upload_result(self, local_filename, target_bucket, target_key):
# 2. Inject overrides
workflow = client.inject_overrides(workflow, req.get('overrides', {}))

# 2.5 Verify models
model_auth = req.get('model_auth', [])
for auth in model_auth:
client.verify_model(
auth['model_id'],
auth['expected_hash'],
auth.get('allow_nsfw', False)
)

# 3. Queue prompt
prompt_response = client.queue_prompt(workflow)
prompt_id = prompt_response['prompt_id']
Expand Down
59 changes: 59 additions & 0 deletions python/test_comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,64 @@ def test_upload_result(self, mock_get):
self.assertEqual(kwargs['Key'], "test_key")
self.assertEqual(kwargs['Body'], b"fake_image_data")

@patch('requests.get')
def test_verify_model_success(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"nsfw": False,
"modelVersions": [
{
"files": [
{
"hashes": {
"SHA256": "ABCDEF123456"
}
}
]
}
]
}
mock_get.return_value = mock_response

result = self.client.verify_model("123", "abcdef123456", False)
self.assertTrue(result)

@patch('requests.get')
def test_verify_model_nsfw_rejected(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"nsfw": True,
"modelVersions": []
}
mock_get.return_value = mock_response

with self.assertRaisesRegex(Exception, "Model 123 is marked as NSFW"):
self.client.verify_model("123", "hash", False)

@patch('requests.get')
def test_verify_model_hash_mismatch(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"nsfw": False,
"modelVersions": [
{
"files": [
{
"hashes": {
"SHA256": "WRONGHASH"
}
}
]
}
]
}
mock_get.return_value = mock_response

with self.assertRaisesRegex(Exception, "SHA256 hash mismatch"):
self.client.verify_model("123", "expectedhash", False)

if __name__ == '__main__':
unittest.main()
53 changes: 53 additions & 0 deletions src/guard/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::{Arc, RwLock};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelEntry {
pub model_id: String,
pub hash: String,
pub license: String,
pub allow_nsfw: bool,
}

#[derive(Debug, Clone)]
pub struct GuardCache {
cache: Arc<RwLock<HashMap<String, ModelEntry>>>,
}

impl Default for GuardCache {
fn default() -> Self {
Self::new()
}
}

impl GuardCache {
pub fn new() -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
}
}
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed

pub fn load_from_file<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
let content = fs::read_to_string(path)?;
let entries: Vec<ModelEntry> = serde_json::from_str(&content)?;

let mut cache = self
.cache
.write()
.map_err(|_| anyhow::anyhow!("Failed to acquire write lock"))?;
cache.clear();
for entry in entries {
cache.insert(entry.model_id.clone(), entry);
}

Ok(())
}

pub fn get_model(&self, model_id: &str) -> Option<ModelEntry> {
let cache = self.cache.read().ok()?;
cache.get(model_id).cloned()
}
}
1 change: 1 addition & 0 deletions src/guard/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod cache;
Loading
Loading