# i2v_chain_helper.py
import torch
import numpy as np
import torchvision.transforms.v2 as T

from comfy.utils import ProgressBar


def tensor_to_pil(img):
    """Convert a single ComfyUI image tensor (H, W, C, float in 0-1) to a PIL RGB image."""
    return T.ToPILImage()(img.permute(2, 0, 1)).convert('RGB')


def calculate_ear(eye_landmarks):
    """
    Calculate the Eye Aspect Ratio (EAR) for one eye.

    For DLib-style landmarks (6 points): points 0 and 3 are the horizontal
    corners, pairs 1-5 and 2-4 span the vertical extent, and
    EAR = (|p1 - p5| + |p2 - p4|) / (2 * |p0 - p3|).
    For InsightFace-style landmarks (10 points): the point ordering is not
    guaranteed, so fall back to the height/width ratio of the eye's
    bounding box.
    """
    if len(eye_landmarks) == 6:
        # DLib style
        p = eye_landmarks
        v1 = np.linalg.norm(p[1] - p[5])
        v2 = np.linalg.norm(p[2] - p[4])
        h = np.linalg.norm(p[0] - p[3])
        return (v1 + v2) / (2.0 * h)
    elif len(eye_landmarks) == 10:
        # InsightFace style (perimeter points): use the bounding box of the
        # eye points rather than assuming a fixed ordering.
        min_p = np.min(eye_landmarks, axis=0)
        max_p = np.max(eye_landmarks, axis=0)
        width = max_p[0] - min_p[0]
        height = max_p[1] - min_p[1]
        return height / width if width > 0 else 0
    return 0
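

# A worked example of the 6-point formula above (a sketch, not part of the
# node; the landmark coordinates are invented for illustration):
#
#   eye = np.array([[0, 0], [1, -0.5], [3, -0.5], [4, 0], [3, 0.5], [1, 0.5]])
#   calculate_ear(eye)  # (1.0 + 1.0) / (2 * 4.0) = 0.25
#
# With the classic 6-point formula an open eye typically scores around
# 0.25-0.3; the 10-point bounding-box fallback produces values on a different
# scale, so min_eyes_openness may need tuning per landmark model.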


class I2VChainHelper:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "analysis_models": ("ANALYSIS_MODELS",),
                "min_face_similarity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "min_eyes_openness": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "reference_image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE", "INT", "IMAGE", "IMAGE", "FLOAT", "FLOAT")
    RETURN_NAMES = ("trimmed_images", "frame_count", "first_frame", "last_frame", "face_similarity", "eyes_openness")
    FUNCTION = "execute"
    CATEGORY = "I2VChain"

    def execute(self, images, analysis_models, min_face_similarity, min_eyes_openness, reference_image=None):
        if images.shape[0] == 0:
            return (images, 0, images, images, 0.0, 0.0)

        # 1. Get the reference identity embedding.
        if reference_image is not None and reference_image.shape[0] > 0:
            ref_img = tensor_to_pil(reference_image[0])
            print("I2VChainHelper: Using provided reference_image for identity.")
        else:
            ref_img = tensor_to_pil(images[0])
            print("I2VChainHelper: No reference_image provided. Using first frame as identity reference.")

        ref_embed = analysis_models.get_embeds(np.array(ref_img))
        if ref_embed is None:
            print("I2VChainHelper: No face detected in reference. Returning empty batch.")
            return (images[:0], 0, images[:0], images[:0], 0.0, 0.0)
        ref_embed = ref_embed / np.linalg.norm(ref_embed)
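
        # With both embeddings L2-normalized, the dot product computed in the
        # scan below equals cosine similarity: ~1.0 for the same identity,
        # near 0 for an unrelated face.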
        last_good_index = 0
        final_similarity = 0.0
        final_openness = 0.0
        pbar = ProgressBar(images.shape[0])

        # Scan from the last frame backwards and stop at the first frame that
        # passes both checks, so the common case (a good final frame) exits
        # after a single iteration.
        for i in range(images.shape[0] - 1, -1, -1):
            img_pil = tensor_to_pil(images[i])
            img_np = np.array(img_pil)

            # Check identity similarity against the reference.
            curr_embed = analysis_models.get_embeds(img_np)
            if curr_embed is None:
                pbar.update(1)
                continue
            curr_embed = curr_embed / np.linalg.norm(curr_embed)
            similarity = np.dot(ref_embed, curr_embed)
            if similarity < min_face_similarity:
                pbar.update(1)
                continue

            # Check that the eyes are open.
            landmarks = analysis_models.get_landmarks(img_np)
            if landmarks is None:
                pbar.update(1)
                continue
            left_eye = landmarks[3]
            right_eye = landmarks[4]
            ear_l = calculate_ear(left_eye)
            ear_r = calculate_ear(right_eye)
            avg_ear = (ear_l + ear_r) / 2.0
            if avg_ear < min_eyes_openness:
                pbar.update(1)
                continue

            # Found the last good frame.
            last_good_index = i
            final_similarity = float(similarity)
            final_openness = float(avg_ear)
            pbar.update(images.shape[0])  # Mark the bar as complete for the UI.
            break

        # If no frame passed both checks, last_good_index stays 0 and the
        # batch is trimmed down to just the first frame.
        trimmed_images = images[:last_good_index + 1]
        first_frame = trimmed_images[0:1]
        last_frame = trimmed_images[-1:]
        return (trimmed_images, trimmed_images.shape[0], first_frame, last_frame, final_similarity, final_openness)
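

# A minimal sketch of driving the node outside a ComfyUI graph (hypothetical
# stubs; the real ANALYSIS_MODELS object must expose get_embeds() and
# get_landmarks(), with landmarks[3]/landmarks[4] holding the eye points, as
# the scan above assumes):
#
#   class DummyAnalysis:
#       def get_embeds(self, img_np):
#           return np.ones(512, dtype=np.float32)  # every frame "matches"
#       def get_landmarks(self, img_np):
#           eye = np.array([[0, 0], [1, -1], [3, -1], [4, 0], [3, 1], [1, 1]], dtype=np.float32)
#           return [None, None, None, eye, eye]    # EAR = 0.5 per eye
#
#   frames = torch.rand(8, 64, 64, 3)
#   node = I2VChainHelper()
#   trimmed, count, first, last, sim, ear = node.execute(
#       frames, DummyAnalysis(), min_face_similarity=0.5, min_eyes_openness=0.4)
#   # count == 8, sim == 1.0, ear == 0.5: the last frame passes immediately.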

NODE_CLASS_MAPPINGS = {
    "I2VChainHelper": I2VChainHelper,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "I2VChainHelper": "I2V Chain Helper",
}