-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimize_mask.py
More file actions
64 lines (52 loc) · 2.35 KB
/
optimize_mask.py
File metadata and controls
64 lines (52 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
from concurrent.futures import ProcessPoolExecutor
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from scipy import stats
from scipy.ndimage import label, binary_dilation, generate_binary_structure
from tqdm import tqdm
from utils.utils import encode_seg
def parse_args():
    """Parse the command-line options of the mask optimisation script.

    Recognised options:
        --area_threshold      int, default 20
        --generated_mask_dir  str, directory the masks are read from
        --save_dir            str, directory the results are written to
        --visualize           flag, off by default
        --num_workers         int, default 10

    Side effect: the directory named by ``--save_dir`` is created if it does
    not exist yet.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--area_threshold", {"type": int, "default": 20}),
        (
            "--generated_mask_dir",
            {"type": str, "default": "../dataset/ILGeneration/coco_semantic_light_text2img/label"},
        ),
        (
            "--save_dir",
            {"type": str, "default": "../dataset/ILGeneration/coco_semantic_light_text2img/optimized_label_p20"},
        ),
        ("--visualize", {"action": "store_true"}),
        ("--num_workers", {"type": int, "default": 10}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    # Make sure the output directory is there before any worker writes to it.
    os.makedirs(parsed.save_dir, exist_ok=True)
    return parsed
def _show_mask(mask):
    """Display ``mask`` in a blocking matplotlib window."""
    # encode_seg is a project helper called on a (1, H, W) batch; presumably it
    # maps label ids to colours -- confirm against utils.utils.
    encoded_mask = encode_seg(mask[np.newaxis, :, :])
    plt.imshow(encoded_mask[0])
    plt.show()


def _merge_small_regions(mask, area_threshold):
    """Relabel, in place, every small connected region of a 2-D label mask.

    For each non-zero label, connected components (scipy.ndimage.label with
    its default structure) whose pixel count is below ``area_threshold`` are
    overwritten with the most frequent value found in the one-pixel ring
    around them. Label 0 regions are never relabelled themselves, but 0 can
    be chosen as the replacement value.

    Args:
        mask: 2-D integer array of label ids; modified in place.
        area_threshold: components with fewer pixels than this are merged.

    Returns:
        The same ``mask`` array, for convenience.
    """
    # Full 3x3 neighbourhood for the ring; loop-invariant, so built once.
    # NOTE(review): components are extracted with the default (edge-only)
    # connectivity while the ring includes diagonals, so the ring may contain
    # pixels of the region's own label -- confirm this is intended.
    ring_structure = generate_binary_structure(2, 2)

    label_set = np.unique(mask)
    label_set = label_set[label_set != 0]
    for label_idx in label_set:
        label_mask, nums = label(mask == label_idx)
        # One pass yields the area of every component (index 0 = background),
        # so a full-image mask is only built for the components that are small.
        areas = np.bincount(label_mask.ravel(), minlength=nums + 1)
        small_ids = np.flatnonzero(areas[1:] < area_threshold) + 1
        for target_idx in small_ids:
            target_mask = label_mask == target_idx
            dilated_mask = binary_dilation(target_mask, structure=ring_structure)
            surrounding_pixel = mask[dilated_mask & ~target_mask]
            if surrounding_pixel.size == 0:
                # The region fills the whole image: there is nothing to merge into.
                continue
            # Most frequent surrounding value; ties resolve to the smallest
            # value because np.unique returns its values sorted.
            values, counts = np.unique(surrounding_pixel, return_counts=True)
            mask[target_mask] = values[np.argmax(counts)]
    return mask


def process_mask(mask_name, args):
    """Clean up one mask file by merging its small regions into their surroundings.

    The mask is read from ``args.generated_mask_dir``; regions smaller than
    ``args.area_threshold`` pixels are relabelled. When ``args.visualize`` is
    set, the mask is shown before and after and nothing is written; otherwise
    the result is saved under ``args.save_dir`` with the same file name.

    Args:
        mask_name: file name of the mask inside ``args.generated_mask_dir``.
        args: namespace providing ``generated_mask_dir``, ``save_dir``,
            ``area_threshold`` and ``visualize``.
    """
    mask_path = os.path.join(args.generated_mask_dir, mask_name)
    # The context manager releases the file handle promptly; np.array copies
    # the pixel data out before the image is closed.
    with Image.open(mask_path) as image:
        mask = np.array(image)

    if args.visualize:
        _show_mask(mask)

    _merge_small_regions(mask, args.area_threshold)

    if args.visualize:
        _show_mask(mask)
    else:
        Image.fromarray(mask).save(os.path.join(args.save_dir, mask_name))
def optimize_mask(args):
    """Run ``process_mask`` over every entry of ``args.generated_mask_dir``.

    The work is distributed over a process pool of ``args.num_workers``
    workers, and a tqdm progress bar tracks completed entries.

    Args:
        args: namespace providing ``generated_mask_dir`` and ``num_workers``,
            passed unchanged to every ``process_mask`` call.
    """
    names = os.listdir(args.generated_mask_dir)
    job_count = len(names)
    # executor.map zips its iterables, so every name gets its own reference to args.
    per_job_args = [args] * job_count

    with ProcessPoolExecutor(max_workers=args.num_workers) as pool:
        outcomes = pool.map(process_mask, names, per_job_args)
        # Draining the iterator advances the progress bar and re-raises any
        # exception that occurred in a worker.
        for _ in tqdm(outcomes, total=job_count):
            pass
# Entry-point guard: with the "spawn" start method, ProcessPoolExecutor workers
# re-import this module, and unguarded top-level code would re-parse the
# arguments and start the whole pipeline again inside every worker.
if __name__ == "__main__":
    _args = parse_args()
    optimize_mask(_args)