-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsort.py
More file actions
executable file
·209 lines (167 loc) · 7.17 KB
/
sort.py
File metadata and controls
executable file
·209 lines (167 loc) · 7.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""Simple image sorter based on histogram differences."""
import argparse
import os
import shutil
import sys
from concurrent.futures import ProcessPoolExecutor
import cv2
class DefaultOptions:
    """Collect the default value of every user-adjustable setting."""

    # ---------------------------------------------------------------- #
    # Common options
    # ---------------------------------------------------------------- #

    # Directory the sorted images are written to
    PATH = os.path.join(os.getcwd(), "sorted")

    # Bin count used when computing the histograms
    #   low numbers  -> less prone to small differences (e.g. added background)
    #   high numbers -> less prone to color shifts or cropping
    # Recommended: 10-50
    BINS = 10

    # Upper limit of processes (!) for image reading/histogram calculation
    THREADS = 1

    # Skip invalid input files instead of aborting
    IGNORE_ERRORS = False

    # ---------------------------------------------------------------- #
    # Advanced options
    # ---------------------------------------------------------------- #

    # Image channels that go into the histogram
    CHANNELS = list(range(3))

    # One (low, high) pair per channel.
    # OpenCV uses 0-179 as Hue range, but 0-255 gave the author better
    # results; it also makes switching between BGR and HSV input easier.
    RANGE = [0, 256] * 3

    # OpenCV normalization method applied to the histograms
    # https://docs.opencv.org/master/d2/de8/group__core__array.html#gad12cefbcb5291cf958a85b4b67b6149f
    NORM = cv2.NORM_L1

    # OpenCV histogram comparison method
    # https://docs.opencv.org/master/d6/dc7/group__imgproc__hist.html#ga994f53817d621e2e4228fc646342d386
    #   Correlation:            0 or cv2.HISTCMP_CORREL
    #   Chi-Square:             1 or cv2.HISTCMP_CHISQR
    #   Intersection:           2 or cv2.HISTCMP_INTERSECT
    #   Bhattacharyya distance: 3 or cv2.HISTCMP_BHATTACHARYYA
    COMP_METHOD = cv2.HISTCMP_INTERSECT
class Image:
    """Hold all image-related information.

    Attributes:
        path: Absolute path of the input file.
        out:  Destination path; ``None`` until assign_label() is called.
        hist: Normalized 3-channel histogram of the image (HSV color space).

    The histogram parameters (channels, bins, range, norm) and the output
    directory are read from the module-level ``opts`` namespace.
    """

    def __init__(self, path):
        """Read the image at *path* and compute its normalized histogram.

        Raises:
            ValueError: If OpenCV cannot read/decode the file.
        """
        self.path = os.path.abspath(path)
        self.out = None

        # cv2.imread() does not raise on failure, it returns None. Catch that
        # here so the user gets the offending file name instead of an opaque
        # OpenCV assertion from cvtColor().
        pixels = cv2.imread(self.path)
        if pixels is None:
            raise ValueError(f"{self.path}: unable to read image")

        # BGR->HSV leads to more accurate results in the author's experience
        data = cv2.cvtColor(pixels, cv2.COLOR_BGR2HSV)
        self.hist = cv2.calcHist([data], opts.channels, None,
                                 [opts.bins, opts.bins, opts.bins], opts.range)

        # Normalizing the histograms leads to more accurate results
        self.hist = cv2.normalize(self.hist, self.hist, norm_type=opts.norm)

    def assign_label(self, label):
        """Assemble the output path as ``<out_dir>/<label>_<input file name>``."""
        in_name = os.path.basename(self.path)
        out_name = f"{label}_{in_name}"
        self.out = os.path.join(opts.out_dir, out_name)

    def copy(self):
        """Copy the input file to the path set by assign_label()."""
        shutil.copy(self.path, self.out)
def valid_image(img):
    """Check an input path for existence and a supported file extension.

    Prints a message for every rejected path. Unless ``opts.ignore_errors``
    is set, a rejected path terminates the program with exit status 1.

    Args:
        img: Path of the candidate image file.

    Returns:
        True if *img* is an existing regular file with a supported
        extension, False otherwise (only reachable with ignore_errors).
    """
    # List of potentially supported image formats (varies across systems)
    # https://docs.opencv.org/4.3.0/d4/da8/group__imgcodecs.html#ga288b8b3da0892bd651fce07b3bbd3a56
    supported = (
        ".bmp", ".dib",                          # Windows bitmaps
        ".jpeg", ".jpg", ".jpe",                 # JPEG files
        ".jp2",                                  # JPEG 2000 files
        ".png",                                  # Portable Network Graphics
        ".webp",                                 # WebP
        ".pbm", ".pgm", ".ppm", ".pxm", ".pnm",  # Portable image format
        ".pfm",                                  # PFM files
        ".sr", ".ras",                           # Sun rasters
        ".tiff", ".tif",                         # TIFF files
        ".exr",                                  # OpenEXR Image files
        ".hdr", ".pic",                          # Radiance HDR
    )

    valid = True
    # Compare case-insensitively so that e.g. "IMG_0001.JPG" is accepted
    ext = os.path.splitext(img)[1].lower()

    if not os.path.exists(img):
        print(f"{img}: image doesn't exist")
        valid = False
    elif not (os.path.isfile(img) and ext in supported):
        print(f"{img}: invalid input image")
        valid = False

    # Abort on the first bad input unless the user asked to skip them
    if not (valid or opts.ignore_errors):
        sys.exit(1)

    return valid
def positive_int(string):
    """Convert a string provided by parse_cli() to a positive int.

    Intended as an argparse ``type=`` callable.

    Args:
        string: Raw command line value.

    Returns:
        The parsed integer, guaranteed to be > 0.

    Raises:
        argparse.ArgumentTypeError: If *string* is not an integer or the
            integer is not strictly positive.
    """
    try:
        value = int(string)
    except ValueError:
        # "from None": the internal ValueError is an implementation detail
        # and should not show up as chained context of the CLI error.
        raise argparse.ArgumentTypeError("invalid positive int") from None

    if value <= 0:
        raise argparse.ArgumentTypeError("invalid positive int")

    return value
def parse_cli():
    """Build the argument parser and return the parsed command line.

    Besides the user-visible options, the advanced settings of
    DefaultOptions (channels, range, norm, comp_method) are injected into
    the returned namespace as fixed defaults.
    """
    fallback = DefaultOptions()
    cli = argparse.ArgumentParser()

    # Positional: one or more input images, stored as absolute paths
    cli.add_argument(
        "images",
        metavar="image",
        type=os.path.abspath,
        nargs="+",
        help="input images to sort",
    )

    # Optional switches
    cli.add_argument(
        "-p", "--path",
        dest="out_dir",
        type=os.path.abspath,
        metavar="PATH",
        default=fallback.PATH,
        help="output directory for sorted images (def: ./sorted)",
    )
    cli.add_argument(
        "-b", "--bins",
        type=positive_int,
        metavar="N",
        default=fallback.BINS,
        help="number of bins for histogram computation (def: %(default)s)",
    )
    cli.add_argument(
        "-t", "--threads",
        type=positive_int,
        default=fallback.THREADS,
        metavar="N",
        help="max. number of processes to use (def: %(default)s)",
    )
    cli.add_argument(
        "-i", "--ignore-errors",
        action="store_true",
        default=fallback.IGNORE_ERRORS,
        help="ignore invalid input file errors",
    )

    # Advanced settings have no command line switch; they only travel along
    # in the namespace.
    advanced = {
        "channels": fallback.CHANNELS,
        "range": fallback.RANGE,
        "norm": fallback.NORM,
        "comp_method": fallback.COMP_METHOD,
    }
    cli.set_defaults(**advanced)

    return cli.parse_args()
def sort(images):
    """Sort images by histogram similarity.

    Greedy nearest-neighbour chaining: for every position ``i`` the image
    among ``images[i+1:]`` whose histogram is most similar to ``images[i]``
    is swapped into position ``i+1``.

    Args:
        images: List of objects exposing a ``hist`` attribute usable with
            cv2.compareHist(). The list is reordered in place.

    Returns:
        The same list object, reordered.
    """
    # Brute-force closest pair of points problem (only with higher score = closer)
    # https://en.wikipedia.org/wiki/Closest_pair_of_points_problem

    # Correlation (0), Intersection (2) -> higher = more similar
    # Chi-Square, Bhattacharyya distance, ... -> lower = more similar
    higher_is_closer = opts.comp_method in {0, 2}

    # Start from the worst possible score. Correlation can be negative, so
    # starting at 0 would never select the closest candidate when all
    # scores are below zero.
    worst_score = float("-inf") if higher_is_closer else float("inf")

    for i in range(len(images) - 1):
        best_score = worst_score
        best_img = i + 1

        for j in range(i + 1, len(images)):
            score = cv2.compareHist(images[i].hist, images[j].hist,
                                    opts.comp_method)
            if higher_is_closer:
                closer = score > best_score
            else:
                closer = score < best_score
            if closer:
                best_score = score
                best_img = j

        # Move the closest match directly behind the current image
        images[i + 1], images[best_img] = images[best_img], images[i + 1]

    return images
def _init_worker(options):
    """Publish the parsed options as module-level ``opts`` in a worker process."""
    global opts
    opts = options


def main():
    """Run the main function body.

    Validates the input paths, computes one histogram per image (optionally
    in several processes), sorts the images by similarity and copies them
    into the output directory with a zero-padded sort number as prefix.
    """
    paths = [i for i in opts.images if valid_image(i)]

    if opts.threads > 1:
        # Workers that are not forked from this process (e.g. the "spawn"
        # start method) re-import this module without running the
        # `__main__` block, so `opts` would be undefined inside Image().
        # Hand the options over explicitly via the pool initializer
        # (requires Python 3.7+).
        with ProcessPoolExecutor(max_workers=opts.threads,
                                 initializer=_init_worker,
                                 initargs=(opts,)) as executor:
            images = list(executor.map(Image, paths))
    else:
        images = [Image(i) for i in paths]

    images = sort(images)

    # Create output directory (including missing parents); no error if it
    # already exists
    os.makedirs(opts.out_dir, exist_ok=True)

    # Copy the images based on their new sorting
    width = len(str(len(images)))
    for label, img in enumerate(images, start=1):
        img.assign_label(f"{label:0{width}}")  # Pad sort number with zeros
        img.copy()
# Script entry point: parse the command line, then run the sorter.
if __name__ == '__main__':
    try:
        # `opts` is deliberately a module-level global: Image, valid_image(),
        # sort() and main() all read their settings from it.
        opts = parse_cli()
        main()
    except KeyboardInterrupt:
        # Ctrl+C: print a short notice instead of a traceback
        print("\nUser interrupt!")