-
Notifications
You must be signed in to change notification settings - Fork 193
Open
Description
Hi, thank you for sharing this nice program.
The following changes to `utils.py` allow training with fewer rotation classes (any divisor of 360). Using 4 classes (rotations of 0, 90, 180 and 270 degrees) may be especially useful when training on small datasets.
There still seems to be something wrong with the visualization (function `display_examples`), but I was able to train with four classes.
def angle_difference(x, y, nb_classes=360):
    """
    Calculate the minimum difference, in degrees, between two angles that
    are given as rotation-class indices.

    Class index ``k`` stands for a rotation of ``k * (360 // nb_classes)``
    degrees, so with the default ``nb_classes=360`` the indices are plain
    degrees. The computation is element-wise and only uses ``-``, ``*`` and
    ``abs()``, so it works on Python numbers, NumPy arrays and backend
    tensors alike.

    Arguments:
        x, y: class indices of the two angles, in ``[0, nb_classes)``.
        nb_classes: number of rotation classes; must be a positive divisor
            of 360.

    Returns:
        The shortest angular distance between the two angles, in degrees,
        in the range ``[0, 180]``.

    Raises:
        ValueError: if ``nb_classes`` is not a positive divisor of 360.
    """
    # Validate explicitly rather than with ``assert`` so the check is not
    # stripped under ``python -O``. Non-positive values would otherwise
    # divide by zero (0) or produce a negative unit angle (e.g. -1).
    if nb_classes <= 0 or 360 % nb_classes != 0:
        raise ValueError('nb_classes should be a positive divisor of 360')
    unit_angle = 360 // nb_classes
    # abs(x - y) * unit_angle is the raw difference in degrees, in [0, 360).
    # Folding it around 180 picks the shorter way round the circle.
    return 180 - abs(abs(x - y) * unit_angle - 180)
def angle_error(y_true, y_pred):
    """
    Calculate the mean difference, in degrees, between the true angles
    and the predicted angles.

    Each angle is represented as a vector over the rotation classes on the
    last axis (one-hot for ``y_true``), and its class is taken as the
    argmax of that vector. The number of classes is read from the static
    shape, so models trained with fewer than 360 classes are handled.

    Raises:
        ValueError: if the number of classes is not statically known from
            either ``y_pred`` or ``y_true``.
    """
    # K.int_shape returns plain Python ints (or None when unknown). By
    # contrast, ``y_pred.shape[1]`` may be a backend Dimension object and
    # hard-codes the class axis as 1. [-1] matches K.argmax's default axis.
    nb_classes = K.int_shape(y_pred)[-1]
    if nb_classes is None:
        nb_classes = K.int_shape(y_true)[-1]
    if nb_classes is None:
        raise ValueError('angle_error needs a statically known number of classes')
    diff = angle_difference(K.argmax(y_true), K.argmax(y_pred), nb_classes=nb_classes)
    return K.mean(K.cast(K.abs(diff), K.floatx()))
(omitted)
In class `RotNetDataGenerator`:
def __init__(self, input, input_shape=None, color_mode='rgb', batch_size=64,
             one_hot=True, preprocess_func=None, rotate=True, crop_center=False,
             crop_largest_rect=False, shuffle=False, seed=None, nb_classes = 360): # nb_classes is added
    """
    Store the generator configuration.

    This patch adds ``nb_classes``: the number of equally spaced rotation
    classes. Rotations are restricted to multiples of
    ``360 // nb_classes`` degrees, so ``nb_classes`` must divide 360
    (e.g. 4 classes -> 0, 90, 180 and 270 degrees). The default of 360
    gives one class per degree (unit angle of 1).
    """
    # inserted: reject class counts that do not split the circle evenly.
    # NOTE(review): ``assert`` is skipped under ``python -O``; an explicit
    # ``raise ValueError`` would keep this check in optimized runs.
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # inserted
    self.images = None
    self.filenames = None
    self.input_shape = input_shape
    self.color_mode = color_mode
    self.batch_size = batch_size
    self.one_hot = one_hot
    self.preprocess_func = preprocess_func
    self.rotate = rotate
    self.crop_center = crop_center
    self.crop_largest_rect = crop_largest_rect
    self.shuffle = shuffle
    self.nb_classes = nb_classes # added: number of rotation classes
    self.unit_angle = 360 // nb_classes # added: degrees covered by one class
    # (remainder of __init__ omitted in the original report)
    ...
(omitted)
def _get_batches_of_transformed_samples(self, index_array):
    """
    Build one batch of rotated images together with their rotation labels.

    Each sample is rotated by a random multiple of ``self.unit_angle``
    degrees (or by 0 when ``self.rotate`` is false). It is labelled with
    that multiple, i.e. the rotation-class index, rather than with the
    angle in degrees.

    Arguments:
        index_array: indices of the samples (into ``self.images`` or
            ``self.filenames``) that make up this batch.
    """
    # create array to hold the images
    batch_x = np.zeros((len(index_array),) + self.input_shape, dtype='float32')
    # create array to hold the labels
    batch_y = np.zeros(len(index_array), dtype='float32')
    # iterate through the current batch
    for i, j in enumerate(index_array):
        if self.filenames is None:
            image = self.images[j]
        else:
            # cv2.imread flag: 1 loads a 3-channel image, 0 loads greyscale
            is_color = int(self.color_mode == 'rgb')
            # NOTE(review): cv2.imread returns None for unreadable paths,
            # which would only surface later as an error in cvtColor or
            # generate_rotated_image.
            image = cv2.imread(self.filenames[j], is_color)
            if is_color:
                # OpenCV loads colour images as BGR; convert to RGB
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.rotate:
            # get a random angle: pick a class index and scale it by the
            # unit angle, so only the nb_classes allowed rotations are used
            rotation_angle = self.unit_angle * np.random.randint(self.nb_classes)
        else:
            rotation_angle = 0
        # generate the rotated image
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=self.input_shape[:2],
            crop_center=self.crop_center,
            crop_largest_rect=self.crop_largest_rect
        )
        # add dimension to account for the channels if the image is greyscale
        if rotated_image.ndim == 2:
            rotated_image = np.expand_dims(rotated_image, axis=2)
        # store the image and label in their corresponding batches; the label
        # is the class index (rotation_angle is always a multiple of
        # unit_angle, so the floor division is exact)
        batch_x[i] = rotated_image
        batch_y[i] = rotation_angle // self.unit_angle
    if self.one_hot:
        # convert the numerical labels to one-hot vectors of length nb_classes
        batch_y = to_categorical(batch_y, self.nb_classes) # modified
    else:
        # scale class indices into [0, 1); index / nb_classes == angle / 360
        batch_y /= self.nb_classes
    # (remainder of the method omitted in the original report)
    ...
(omitted)
def display_examples(model, input, num_images=5, size=None, crop_center=False,
                     crop_largest_rect=False, preprocess_func=None, save_path=None,
                     nb_classes = 360): # nb_class was added
    """
    Given a model that predicts the rotation angle of an image,
    and a NumPy array of images or a list of image paths, display
    the specified number of example images in three columns:
    Original, Rotated and Corrected.

    ``nb_classes`` (added by this patch) is the number of rotation classes
    the model predicts. It must divide 360, and class index ``k``
    corresponds to a rotation of ``k * (360 // nb_classes)`` degrees.
    NOTE(review): it is not inferred from ``model``. If it differs from the
    number of classes the model was trained with, the true and predicted
    class indices below are on different scales. Confirm that callers pass
    the training value; this may be the visualization problem reported
    with this patch.
    """
    # added: same divisibility requirement as the data generator.
    # NOTE(review): ``assert`` is skipped under ``python -O``.
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # added
    unit_angle = 360 // nb_classes # added: degrees covered by one class
    if isinstance(input, (np.ndarray)):
        # in-memory images: default the output size to the input image size
        images = input
        N, h, w = images.shape[:3]
        if not size:
            size = (h, w)
        # np.random.choice samples with replacement by default, so the same
        # image may be shown more than once
        indexes = np.random.choice(N, num_images)
        images = images[indexes, ...]
    else:
        # list of image paths: load the sampled files and convert them to RGB
        # NOTE(review): ``size`` is not defaulted on this path; if it is None
        # it is passed to generate_rotated_image unchanged.
        images = []
        filenames = input
        N = len(filenames)
        indexes = np.random.choice(N, num_images)
        for i in indexes:
            image = cv2.imread(filenames[i])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)
        images = np.asarray(images)
    x = []
    y = []
    for image in images:
        # rotate by a random multiple of unit_angle, i.e. one of the
        # nb_classes allowed rotations
        rotation_angle = np.random.randint(nb_classes) * unit_angle
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=size,
            crop_center=crop_center,
            crop_largest_rect=crop_largest_rect
        )
        x.append(rotated_image)
        # the label is the class index, not the angle in degrees
        y.append(rotation_angle // unit_angle)
    x = np.asarray(x, dtype='float32')
    y = np.asarray(y, dtype='float32')
    if x.ndim == 3:
        # images without a channel axis (presumably greyscale): add one
        x = np.expand_dims(x, axis=3)
    # NOTE(review): this one-hot encoding is undone by the argmax below,
    # so y ends up holding the class indices again.
    y = to_categorical(y, nb_classes)
    # keep the un-preprocessed rotated images for display
    x_rot = np.copy(x)
    if preprocess_func:
        x = preprocess_func(x)
    y = np.argmax(y, axis=1)
    # predicted class index per image (argmax over the model's outputs)
    y_pred = np.argmax(model.predict(x), axis=1)
    plt.figure(figsize=(10.0, 2 * num_images))
    title_fontdict = {
        'fontsize': 14,
        'fontweight': 'bold'
    }
    fig_number = 0
    for rotated_image, true_angle, predicted_angle in zip(x_rot, y, y_pred):
        # convert class indices back to angles in degrees
        true_angle *= unit_angle # added
        predicted_angle *= unit_angle # added
        # (remainder of the function omitted in the original report)
        ...
Metadata
Metadata
Assignees
Labels
No labels