-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmask_split.py
More file actions
272 lines (215 loc) · 8.98 KB
/
mask_split.py
File metadata and controls
272 lines (215 loc) · 8.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import torch
import torch.nn.functional as F
import cv2
import numpy as np
class MaskSplit:
    """Node that splits a mask into one mask per contour-delimited region.

    Each region found by OpenCV contour analysis becomes its own mask, and
    the input image is repeated once per region so that both outputs have
    matching batch sizes.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input declaration for this node (one image, one mask)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("segmented_images", "segmented_masks")
    FUNCTION = "segment_mask"
    CATEGORY = "CyberEveLoop🐰"

    def find_top_left_point(self, mask_np):
        """Return ``(x, y)`` of the sort anchor of a 2-D mask array.

        The anchor is the topmost pixel (smallest y) within the leftmost
        non-zero column (smallest x). If the mask has no non-zero pixel,
        ``(inf, inf)`` is returned so that empty regions sort last.
        """
        y_coords, x_coords = np.nonzero(mask_np)
        if len(x_coords) == 0:
            return float('inf'), float('inf')
        # Leftmost column first, then the topmost pixel inside that column.
        min_x = np.min(x_coords)
        min_y = np.min(y_coords[x_coords == min_x])
        return min_x, min_y

    def segment_mask(self, mask, image):
        """Split ``mask`` into per-region masks and repeat ``image`` to match.

        Args:
            mask: A torch tensor (2-D, or batched with the batch first) or a
                2-D numpy array with values scaled so that ``mask * 255``
                fits in uint8. For a batched tensor only ``mask[0]`` is
                analysed.
            image: A torch tensor; a 3-D image gets a leading batch dimension.

        Returns:
            A tuple ``(images, masks)``. ``masks`` is the concatenation along
            dim 0 of one ``[1, H, W]`` float mask per region, ordered by
            ``find_top_left_point``. ``images`` is ``image`` repeated along
            dim 0 once per region. If no contour is found, the original mask
            is returned as the only region.
        """
        is_tensor = isinstance(mask, torch.Tensor)
        # Remember where the input lives so the outputs end up on the same device.
        device = mask.device if is_tensor else torch.device('cpu')

        if is_tensor:
            if len(mask.shape) == 2:
                mask = mask.unsqueeze(0)
            mask_np = (mask[0] * 255).detach().cpu().numpy().astype(np.uint8)
        else:
            mask_np = (mask * 255).astype(np.uint8)

        # Taking the last two return values works with both the 2-tuple
        # (OpenCV 4.x) and the 3-tuple (OpenCV 3.x) forms of findContours.
        contours, hierarchy = cv2.findContours(
            mask_np,
            cv2.RETR_TREE,
            cv2.CHAIN_APPROX_SIMPLE,
        )[-2:]

        regions = []  # entries are (mask_tensor, min_x, min_y)
        if hierarchy is not None and len(contours) > 0:
            hierarchy = hierarchy[0]

            # One filled canvas per contour, indexed like `contours`.
            filled = []
            for contour in contours:
                canvas = np.zeros_like(mask_np)
                cv2.drawContours(canvas, [contour], -1, 255, -1)
                filled.append(canvas)

            consumed = set()  # contours already subtracted from their parent
            for idx, node in enumerate(hierarchy):
                if idx in consumed:
                    continue

                region = filled[idx].copy()
                # node[2] is the first child; hierarchy[child][0] is the next
                # sibling. Direct children are carved out of this region and
                # are not emitted as regions of their own.
                child = node[2]
                while child != -1:
                    region = cv2.subtract(region, filled[child])
                    consumed.add(child)
                    child = hierarchy[child][0]

                min_x, min_y = self.find_top_left_point(region)
                region_tensor = torch.from_numpy(region).float() / 255.0
                region_tensor = region_tensor.unsqueeze(0).to(device)
                regions.append((region_tensor, min_x, min_y))

        # No contour at all: fall back to the mask that was passed in.
        if not regions:
            if is_tensor:
                regions.append((mask, 0, 0))
            else:
                fallback = torch.from_numpy(mask).float()
                if len(fallback.shape) == 2:
                    fallback = fallback.unsqueeze(0)
                regions.append((fallback.to(device), 0, 0))

        # Order regions by their anchor point: x first, then y.
        regions.sort(key=lambda entry: (entry[1], entry[2]))

        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        # Build both batches in one step instead of growing them with
        # repeated torch.cat calls inside a loop.
        result_masks = torch.cat([entry[0] for entry in regions], dim=0)
        result_images = image.repeat(len(regions), *([1] * (image.dim() - 1)))

        return (result_images, result_masks)
class MaskMerge:
    """Node that composites processed images back onto an original image.

    Each processed image is blended onto the running result using its
    matching mask as the per-pixel blend weight.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input declaration: one required image, optional images/masks."""
        return {
            "required": {
                "original_image": ("IMAGE",),
            },
            "optional": {
                "processed_images": ("IMAGE", {"forceInput": True}),
                "masks": ("MASK", {"forceInput": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("merged_image",)
    FUNCTION = "merge_masked_images"
    CATEGORY = "CyberEveLoop🐰"

    def standardize_input(self, image, processed_images=None, masks=None):
        """Normalise the inputs to batched tensors.

        - image: [H,W,C] -> [1,H,W,C]
        - processed_images: tensor or list/tuple of tensors -> [B,H,W,C]
        - masks: tensor or list/tuple of tensors -> [B,H,W]

        List items that lack a batch dimension get one before concatenation,
        so a list of [H,W,C] images or [H,W] masks stacks along the batch
        axis rather than along the height axis.

        Returns:
            The tuple ``(image, processed_images, masks)``.

        Raises:
            AssertionError: if a normalised input does not have the expected rank.
        """
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        assert len(image.shape) == 4, f"Original image must be 4D [B,H,W,C], got shape {image.shape}"

        if processed_images is not None:
            if isinstance(processed_images, (list, tuple)):
                processed_images = torch.cat(
                    [t.unsqueeze(0) if t.dim() == 3 else t for t in processed_images],
                    dim=0,
                )
            if len(processed_images.shape) == 3:
                processed_images = processed_images.unsqueeze(0)
            assert len(processed_images.shape) == 4, \
                f"Processed images must be 4D [B,H,W,C], got shape {processed_images.shape}"

        if masks is not None:
            if isinstance(masks, (list, tuple)):
                masks = torch.cat(
                    [t.unsqueeze(0) if t.dim() == 2 else t for t in masks],
                    dim=0,
                )
            if len(masks.shape) == 2:
                masks = masks.unsqueeze(0)
            assert len(masks.shape) == 3, f"Masks must be 3D [B,H,W], got shape {masks.shape}"

        return image, processed_images, masks

    def resize_tensor(self, x, size, mode='bilinear'):
        """Resize the spatial dimensions of ``x`` to ``size`` = (height, width).

        A 3-D input temporarily gets a leading dimension, which is removed
        again before returning. If the last dimension is 1, 3 or 4 the tensor
        is treated as channels-last ([..., H, W, C]) and is permuted to
        channels-first for ``F.interpolate`` and back afterwards. Any other
        layout (for example a [B,H,W] mask stack) is interpolated as is.

        The permute back happens only if the forward permute happened. The
        earlier check on ``shape[1]`` after interpolation mis-ordered the axes
        of mask stacks whose batch size was 1, 3 or 4.
        """
        orig_dim = x.dim()
        if orig_dim == 3:
            x = x.unsqueeze(0)

        # Heuristic: a trailing dimension of 1/3/4 means channels-last.
        channels_last = x.shape[-1] in (1, 3, 4)
        if channels_last:
            x = x.permute(0, 3, 1, 2)

        # align_corners is only accepted by the interpolating modes.
        align_corners = False if mode in ('bilinear', 'bicubic') else None
        x = F.interpolate(x, size=size, mode=mode, align_corners=align_corners)

        if channels_last:
            x = x.permute(0, 2, 3, 1)

        if orig_dim == 3:
            x = x.squeeze(0)

        return x

    def merge_masked_images(self, original_image, processed_images=None, masks=None):
        """Blend each processed image onto the original using its mask.

        Args:
            original_image: base image, [H,W,C] or [B,H,W,C].
            processed_images: images to paste, tensor or list; optional.
            masks: blend weights, tensor or list; optional.

        Returns:
            A one-element tuple holding the merged [B,H,W,C] image. If either
            optional input is missing, ``original_image`` is returned
            unchanged. Processed images and masks are resized to the
            original's height and width where needed. Images and masks are
            consumed in pairs, and any unpaired extras are ignored.
        """
        if processed_images is None or masks is None:
            return (original_image,)

        original_image, processed_images, masks = self.standardize_input(
            original_image, processed_images, masks
        )

        result = original_image.clone()

        target_size = (original_image.shape[1], original_image.shape[2])

        if tuple(processed_images.shape[1:3]) != target_size:
            processed_images = self.resize_tensor(
                processed_images,
                target_size,
                mode='bilinear'
            )

        if tuple(masks.shape[1:3]) != target_size:
            masks = self.resize_tensor(
                masks,
                target_size,
                mode='bilinear'
            )

        # A trailing singleton dimension broadcasts over any channel count
        # (RGB, RGBA, single channel) instead of assuming exactly three.
        masks = masks.unsqueeze(-1)

        # Apply the layers in order; later layers are drawn over earlier ones.
        for layer, weight in zip(processed_images, masks):
            result = weight * layer + (1 - weight) * result

        assert len(result.shape) == 4, "Output must be 4D [B,H,W,C]"

        return (result,)
# Node identifier -> implementing class.
# NOTE(review): presumably merged into the package-level node registry
# elsewhere; the consumer is not visible in this file.
Mask_CLASS_MAPPINGS = {
    "CyberEve_MaskSegmentation": MaskSplit,
    "CyberEve_MaskMerge": MaskMerge,
}
# Node identifier -> display name; the keys are the same identifiers used
# in Mask_CLASS_MAPPINGS.
Mask_DISPLAY_NAME_MAPPINGS = {
    "CyberEve_MaskSegmentation": "Mask Segmentation🐰",
    "CyberEve_MaskMerge": "Mask Merge🐰",
}